diff --git a/Directory.Packages.props b/Directory.Packages.props index 05e11c0..221d1e3 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -8,11 +8,11 @@ - + - \ No newline at end of file + diff --git a/Injectio.slnx b/Injectio.slnx index 6b7aea5..babb955 100644 --- a/Injectio.slnx +++ b/Injectio.slnx @@ -8,10 +8,11 @@ + - + diff --git a/README.md b/README.md index 35978a6..b1970f5 100644 --- a/README.md +++ b/README.md @@ -17,21 +17,20 @@ Source generator that helps register attribute marked services in the dependency - Module method registration - Duplicate Strategy - Skip,Replace,Append - Registration Strategy - Self, Implemented Interfaces, Self With Interfaces +- Decorator registration (`RegisterDecorator`) — no runtime dependencies ### Usage +#### Requirements + +Requires Roslyn 4.14 or later. This means Visual Studio 2022 version 17.14+ or Visual Studio 2026+, a current Rider release, and the .NET 9.0.300 SDK or newer. Older toolchains will not load the analyzer. + #### Add package Add the nuget package project to your projects. `dotnet add package Injectio` -Prevent dependances from including Injectio - -```xml - -``` - ### Registration Attributes Place registration attribute on class. The class will be discovered and registered. @@ -40,6 +39,7 @@ Place registration attribute on class. The class will be discovered and registe - `[RegisterScoped]` Marks the class as a scoped service - `[RegisterTransient]` Marks the class as a transient service - `[RegisterServices]` Marks the method to be called to register services +- `[RegisterDecorator]` Marks the class as a decorator around an existing service #### Attribute Properties @@ -217,6 +217,119 @@ public class ServiceFactoryKeyed : IServiceKeyed } ``` +#### Decorators + +Use the `RegisterDecorator` attribute to wrap an existing service registration without adding +any runtime dependencies. The generator emits all decoration helpers directly into the +consumer assembly. + +Decorators inherit the lifetime of the service they decorate. Apply multiple decorators by +ordering them with the `Order` property — lower values are innermost (applied first), higher +values are outermost (applied last). + +```c# +public interface IService { } + +[RegisterSingleton] +public class Service : IService { } + +[RegisterDecorator(Order = 1)] +public class LoggingDecorator : IService +{ + public LoggingDecorator(IService inner) { } +} + +[RegisterDecorator(Order = 2)] +public class CachingDecorator : IService +{ + public CachingDecorator(IService inner) { } +} +``` + +Resolution order for the sample above: `CachingDecorator → LoggingDecorator → Service`. + +##### Decorator Attribute Properties + +| Property | Description | +|--------------------|------------------------------------------------------------------------------------------------| +| ServiceType | The type of service to decorate. Required unless the generic attribute form is used. | +| ImplementationType | The decorator type. If not set, the class the attribute is on will be used. | +| ServiceKey | Decorate a specific keyed registration. Requires .NET 8+ Microsoft.Extensions.DependencyInjection. | +| AnyKey | When `true`, decorate every keyed registration of `ServiceType` regardless of its key. | +| Factory | Name of a static factory method that builds the decorator. | +| Order | Ordering within the decoration chain. Lower = innermost. | +| Tags | Comma/semicolon-delimited list of registration tags. | + +##### Keyed decoration + +Decorate a single keyed variant, or use `AnyKey` to decorate them all: + +```c# +[RegisterSingleton(ServiceKey = "alpha")] +public class AlphaService : IService { } + +[RegisterDecorator(AnyKey = true)] +public class LoggingDecorator : IService +{ + public LoggingDecorator(IService inner) { } +} +``` + +##### Factory-built decorators + +Provide a static factory on the decorator class for complex construction: + +```c# +[RegisterDecorator(Factory = nameof(Create))] +public class LoggingDecorator : IService +{ + public LoggingDecorator(IService inner) { } + + public static IService Create(IServiceProvider serviceProvider, IService inner) + => new LoggingDecorator(inner); +} +``` + +For keyed decorators the factory takes an additional `object?` parameter for the key: + +```c# +public static IService Create(IServiceProvider serviceProvider, object? serviceKey, IService inner) + => new LoggingDecorator(inner); +``` + +##### Open-generic decoration + +Open-generic decorators apply to every closed registration of the matching service type. +The generator supports decorating closed-generic registrations with an open-generic decorator +class; purely open-generic implementation registrations (e.g. `(IRepo<>, Repo<>)`) are not +decorated at runtime due to a Microsoft.Extensions.DependencyInjection limitation on factory +registrations for open generic service types. + +```c# +public interface IRepo { } + +[RegisterSingleton, StringRepo>] +public class StringRepo : IRepo { } + +[RegisterDecorator(ServiceType = typeof(IRepo<>))] +public class LoggingRepo : IRepo +{ + public LoggingRepo(IRepo inner) { } +} +``` + +##### Tags + +Decorators support the same tag-filtering as registrations: + +```c# +[RegisterDecorator(Tags = "FrontEnd")] +public class FrontEndLoggingDecorator : IService +{ + public FrontEndLoggingDecorator(IService inner) { } +} +``` + #### Register Method When the service registration is complex, use the `RegisterServices` attribute on a method that has a parameter of `IServiceCollection` or `ServiceCollection` diff --git a/src/Injectio.Attributes/DuplicateStrategy.cs b/src/Injectio.Attributes/DuplicateStrategy.cs deleted file mode 100644 index a383547..0000000 --- a/src/Injectio.Attributes/DuplicateStrategy.cs +++ /dev/null @@ -1,20 +0,0 @@ -namespace Injectio.Attributes; - -/// -/// Service registration duplicate strategies -/// -public enum DuplicateStrategy -{ - /// - /// Skips registrations for services that already exists. - /// - Skip = 0, - /// - /// Replaces existing service registrations. - /// - Replace = 1, - /// - /// Appends a new registration for existing services. - /// - Append = 2 -} diff --git a/src/Injectio.Attributes/Injectio.Attributes.csproj b/src/Injectio.Attributes/Injectio.Attributes.csproj deleted file mode 100644 index b39d861..0000000 --- a/src/Injectio.Attributes/Injectio.Attributes.csproj +++ /dev/null @@ -1,17 +0,0 @@ - - - - netstandard2.0;net8.0;net9.0;net10.0 - latest - Injectio - - - - - - - - - - - diff --git a/src/Injectio.Attributes/RegisterAttribute.cs b/src/Injectio.Attributes/RegisterAttribute.cs deleted file mode 100644 index d590536..0000000 --- a/src/Injectio.Attributes/RegisterAttribute.cs +++ /dev/null @@ -1,71 +0,0 @@ -namespace Injectio.Attributes; - -/// -/// Attribute to indicate the target class should be register for dependency injection -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public abstract class RegisterAttribute : Attribute -{ - /// - /// Initializes a new instance of the class. - /// - protected RegisterAttribute() - { - Duplicate = DuplicateStrategy.Skip; - Registration = RegistrationStrategy.SelfWithInterfaces; - } - - - /// - /// The that implements the service. If not set, the class the interface is on will be used. - /// - /// - public Type? ImplementationType { get; set; } - - /// - /// The of the service - /// - /// - public Type? ServiceType { get; set; } - - /// - /// Gets or sets the key of the service. - /// - /// The service key. - /// - public object? ServiceKey { get; set; } - - /// - /// Name of a factory method to create new instances of the service implementation - /// - /// - /// The method signature must be Func{IServiceProvider, object} and must be static - /// - /// - public string? Factory { get; set; } - - /// - /// Gets or sets the duplicate. - /// - /// - /// The duplicate. - /// - public DuplicateStrategy Duplicate { get; set; } - - /// - /// Gets or sets the registration. - /// - /// - /// The registration. - /// - public RegistrationStrategy Registration { get; set; } - - /// - /// Gets or sets the comma delimited list of service registration tags. - /// - /// - /// The comma delimited list of service registration tags. - /// - public string? Tags { get; set; } -} diff --git a/src/Injectio.Attributes/RegisterScopedAttribute.cs b/src/Injectio.Attributes/RegisterScopedAttribute.cs deleted file mode 100644 index 0f3cb00..0000000 --- a/src/Injectio.Attributes/RegisterScopedAttribute.cs +++ /dev/null @@ -1,69 +0,0 @@ -namespace Injectio.Attributes; - -/// -/// Attribute to indicate the target class should be register for dependency injection as a scoped service -/// -/// Register the class as an implementation for IService -/// -/// [RegisterScoped] -/// public class ScopedService : IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterScopedAttribute : RegisterAttribute -{ -} - -#if NET7_0_OR_GREATER -/// -/// Attribute to indicate the target class should be register for dependency injection as a scoped service -/// -/// The type of the service to add. -/// Register the class as an implementation for IService -/// -/// [RegisterScoped<IService>] -/// public class ScopedService : IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterScopedAttribute : RegisterScopedAttribute - where TService : class -{ - /// - /// Initializes a new instance of the class. - /// - public RegisterScopedAttribute() - { - ServiceType = typeof(TService); - } -} - -/// -/// Attribute to indicate the target class should be register for dependency injection as a scoped service -/// -/// The type of the service to add. -/// The type of the implementation to use. -/// Register the ScopedService class as an implementation for IService -/// -/// [RegisterScoped<IService, ScopedService>] -/// public class ScopedService: IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterScopedAttribute : RegisterScopedAttribute - where TService : class - where TImplementation : class, TService -{ - /// - /// Initializes a new instance of the class. - /// - public RegisterScopedAttribute() - { - ServiceType = typeof(TService); - ImplementationType = typeof(TImplementation); - } -} -#endif diff --git a/src/Injectio.Attributes/RegisterServicesAttribute.cs b/src/Injectio.Attributes/RegisterServicesAttribute.cs deleted file mode 100644 index b2367bf..0000000 --- a/src/Injectio.Attributes/RegisterServicesAttribute.cs +++ /dev/null @@ -1,21 +0,0 @@ -namespace Injectio.Attributes; - -/// Attribute to indicate the method should be called to register services -/// use the RegisterServices attribute -/// -/// public class RegistrationModule -/// { -/// [RegisterServices] -/// public static void Register(IServiceCollection services) -/// { -/// services.TryAddTransient<IModuleService, ModuleService>(); -/// } -/// } -/// -/// -[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterServicesAttribute : Attribute -{ - -} diff --git a/src/Injectio.Attributes/RegisterSingletonAttribute.cs b/src/Injectio.Attributes/RegisterSingletonAttribute.cs deleted file mode 100644 index 806ce7a..0000000 --- a/src/Injectio.Attributes/RegisterSingletonAttribute.cs +++ /dev/null @@ -1,69 +0,0 @@ -namespace Injectio.Attributes; - -/// -/// Attribute to indicate the target class should be register for dependency injection as a singleton service -/// -/// Register the class as an implementation for IService -/// -/// [RegisterSingleton] -/// public class SingletonService : IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterSingletonAttribute : RegisterAttribute -{ -} - -#if NET7_0_OR_GREATER -/// -/// Attribute to indicate the target class should be register for dependency injection as a singleton service -/// -/// The type of the service to add. -/// Register the class as an implementation for IService -/// -/// [RegisterSingleton<IService>] -/// public class SingletonService : IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterSingletonAttribute : RegisterSingletonAttribute - where TService : class -{ - /// - /// Initializes a new instance of the class. - /// - public RegisterSingletonAttribute() - { - ServiceType = typeof(TService); - } -} - -/// -/// Attribute to indicate the target class should be register for dependency injection as a singleton service -/// -/// The type of the service to add. -/// The type of the implementation to use. -/// Register the SingletonService class as an implementation for IService -/// -/// [RegisterSingleton<IService, SingletonService>] -/// public class SingletonService: IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterSingletonAttribute : RegisterSingletonAttribute - where TService : class - where TImplementation : class, TService -{ - /// - /// Initializes a new instance of the class. - /// - public RegisterSingletonAttribute() - { - ServiceType = typeof(TService); - ImplementationType = typeof(TImplementation); - } -} -#endif diff --git a/src/Injectio.Attributes/RegisterTransientAttribute.cs b/src/Injectio.Attributes/RegisterTransientAttribute.cs deleted file mode 100644 index c3564f6..0000000 --- a/src/Injectio.Attributes/RegisterTransientAttribute.cs +++ /dev/null @@ -1,69 +0,0 @@ -namespace Injectio.Attributes; - -/// -/// Attribute to indicate the target class should be register for dependency injection as a transient service -/// -/// Register the class as an implementation for IService -/// -/// [RegisterTransient] -/// public class TransientService : IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterTransientAttribute : RegisterAttribute -{ -} - -#if NET7_0_OR_GREATER -/// -/// Attribute to indicate the target class should be register for dependency injection as a transient service -/// -/// The type of the service to add. -/// Register the class as an implementation for IService -/// -/// [RegisterTransient<IService>] -/// public class TransientService : IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterTransientAttribute : RegisterTransientAttribute - where TService : class -{ - /// - /// Initializes a new instance of the class. - /// - public RegisterTransientAttribute() - { - ServiceType = typeof(TService); - } -} - -/// -/// Attribute to indicate the target class should be register for dependency injection as a transient service -/// -/// The type of the service to add. -/// The type of the implementation to use. -/// Register the TransientService class as an implementation for IService -/// -/// [RegisterTransient<IService, TransientService>] -/// public class TransientService: IService { } -/// -/// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] -[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] -public class RegisterTransientAttribute : RegisterTransientAttribute - where TService : class - where TImplementation : class, TService -{ - /// - /// Initializes a new instance of the class. - /// - public RegisterTransientAttribute() - { - ServiceType = typeof(TService); - ImplementationType = typeof(TImplementation); - } -} -#endif diff --git a/src/Injectio.Attributes/RegistrationStrategy.cs b/src/Injectio.Attributes/RegistrationStrategy.cs deleted file mode 100644 index 3355135..0000000 --- a/src/Injectio.Attributes/RegistrationStrategy.cs +++ /dev/null @@ -1,26 +0,0 @@ -namespace Injectio.Attributes; - -/// -/// Register service type strategies -/// -public enum RegistrationStrategy -{ - /// - /// Registers each matching concrete type as itself - /// - Self = 0, - /// - /// Registers each matching concrete type as all of its implemented interfaces. - /// - ImplementedInterfaces = 1, - /// - /// Registers each matching concrete type as all of its implemented interfaces and itself - /// - SelfWithInterfaces = 2, - /// - /// Registers each matching concrete type as all of its implemented interfaces and itself. - /// For the interfaces a proxy-factory resolves the service from its type-name, so only one instance is created per lifetime - /// - /// For open-generic registrations, this behaves like - SelfWithProxyFactory = 3 -} diff --git a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md index 65956d6..9ecbdb3 100644 --- a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md +++ b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md @@ -14,3 +14,10 @@ INJ0006 | Usage | Warning | Factory method has invalid signature INJ0007 | Usage | Warning | Implementation does not implement service type INJ0008 | Usage | Warning | Implementation type is abstract INJ0009 | Usage | Warning | RegisterServices on non-static method in abstract class +INJ0010 | Usage | Warning | Decorator does not implement service type +INJ0011 | Usage | Warning | Decorator is missing service type +INJ0012 | Usage | Warning | Decorator has no constructor accepting the inner service +INJ0013 | Usage | Warning | Decorator factory method not found +INJ0014 | Usage | Warning | Decorator factory method has invalid signature +INJ0015 | Usage | Warning | Keyed decoration is not supported for open-generic services +INJ0016 | Usage | Warning | Decorator target service is not registered in this compilation diff --git a/src/Injectio.Generators/DiagnosticDescriptors.cs b/src/Injectio.Generators/DiagnosticDescriptors.cs index 01c7157..f31812b 100644 --- a/src/Injectio.Generators/DiagnosticDescriptors.cs +++ b/src/Injectio.Generators/DiagnosticDescriptors.cs @@ -86,4 +86,69 @@ public static class DiagnosticDescriptors defaultSeverity: DiagnosticSeverity.Warning, isEnabledByDefault: true ); + + public static readonly DiagnosticDescriptor DecoratorDoesNotImplementService = new( + id: "INJ0010", + title: "Decorator does not implement service type", + messageFormat: "Decorator '{0}' does not implement or inherit from service type '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorMissingServiceType = new( + id: "INJ0011", + title: "Decorator is missing service type", + messageFormat: "Decorator '{0}' must specify a ServiceType either via the generic attribute or the ServiceType property", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorMissingInnerConstructor = new( + id: "INJ0012", + title: "Decorator has no constructor accepting the inner service", + messageFormat: "Decorator '{0}' must expose a public constructor whose first parameter is of type '{1}' (or use Factory)", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorFactoryNotFound = new( + id: "INJ0013", + title: "Decorator factory method not found", + messageFormat: "Decorator factory method '{0}' was not found on type '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorFactoryInvalidSignature = new( + id: "INJ0014", + title: "Decorator factory method has invalid signature", + messageFormat: "Decorator factory method '{0}' on type '{1}' must be static and accept (IServiceProvider, TService) for non-keyed or (IServiceProvider, object?, TService) for keyed decorators", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorOpenGenericKeyed = new( + id: "INJ0015", + title: "Keyed decoration is not supported for open-generic services", + messageFormat: "Decorator '{0}' targets open-generic service '{1}' and cannot be used with ServiceKey or AnyKey", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + customTags: [WellKnownDiagnosticTags.CompilationEnd] + ); + + public static readonly DiagnosticDescriptor DecoratorTargetNotRegistered = new( + id: "INJ0016", + title: "Decorator target service is not registered in this compilation", + messageFormat: "Decorator '{0}' targets service '{1}' but no matching registration was found; decoration will be skipped at runtime if the service is not registered elsewhere", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + customTags: [WellKnownDiagnosticTags.CompilationEnd] + ); } diff --git a/src/Injectio.Generators/EquatableArray.cs b/src/Injectio.Generators/EquatableArray.cs deleted file mode 100644 index 2df341e..0000000 --- a/src/Injectio.Generators/EquatableArray.cs +++ /dev/null @@ -1,60 +0,0 @@ -using System.Collections; -using System.Diagnostics.CodeAnalysis; - -namespace Injectio.Generators; - -[ExcludeFromCodeCoverage] -public readonly struct EquatableArray : IEquatable>, IEnumerable - where T : IEquatable -{ - public static readonly EquatableArray Empty = new(); - - - public EquatableArray() : this([]) { } - - public EquatableArray(T[] array) => Array = array ?? []; - - public EquatableArray(IEnumerable items) => Array = items.ToArray() ?? []; - - - public T[] Array { get; } - - public int Count => Array.Length; - - - public ReadOnlySpan AsSpan() => Array.AsSpan(); - - public T[] AsArray() => Array; - - - public static bool operator ==(EquatableArray left, EquatableArray right) => left.Equals(right); - - public static bool operator !=(EquatableArray left, EquatableArray right) => !left.Equals(right); - - public bool Equals(EquatableArray array) => Array.AsSpan().SequenceEqual(array.AsSpan()); - - public override bool Equals(object? obj) => obj is EquatableArray array && Equals(this, array); - - public override int GetHashCode() - { - if (Array is not T[] array) - return 0; - - var hashCode = 16777619; - - for (int i = 0; i < array.Length; i++) - hashCode = unchecked((hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(array[i])); - - return hashCode; - } - - - IEnumerator IEnumerable.GetEnumerator() => (Array as IEnumerable).GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => Array.GetEnumerator(); - - - public static implicit operator EquatableArray(T[] array) => new(array); - - public static implicit operator EquatableArray(List items) => new(items); -} diff --git a/src/Injectio.Generators/Extensions/StringExtensions.cs b/src/Injectio.Generators/Extensions/StringExtensions.cs index 996d285..414a9cb 100644 --- a/src/Injectio.Generators/Extensions/StringExtensions.cs +++ b/src/Injectio.Generators/Extensions/StringExtensions.cs @@ -1,6 +1,8 @@ using System.Diagnostics.CodeAnalysis; using System.Text.RegularExpressions; +using Injectio.Generators.Infrastructure; + using Microsoft.CodeAnalysis; namespace Injectio.Generators.Extensions; diff --git a/src/Injectio.Generators/Infrastructure/CollectionBuilderAttribute.cs b/src/Injectio.Generators/Infrastructure/CollectionBuilderAttribute.cs new file mode 100644 index 0000000..8dae3d6 --- /dev/null +++ b/src/Injectio.Generators/Infrastructure/CollectionBuilderAttribute.cs @@ -0,0 +1,14 @@ +#pragma warning disable IDE0130 // Namespace does not match folder structure + +using System.ComponentModel; + +namespace System.Runtime.CompilerServices; + +[EditorBrowsable(EditorBrowsableState.Never)] +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Interface, Inherited = false)] +internal sealed class CollectionBuilderAttribute(Type builderType, string methodName) : Attribute +{ + public Type BuilderType { get; } = builderType; + + public string MethodName { get; } = methodName; +} diff --git a/src/Injectio.Generators/Infrastructure/EquatableArray.cs b/src/Injectio.Generators/Infrastructure/EquatableArray.cs new file mode 100644 index 0000000..c632554 --- /dev/null +++ b/src/Injectio.Generators/Infrastructure/EquatableArray.cs @@ -0,0 +1,128 @@ +using System.Collections; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace Injectio.Generators.Infrastructure; + + +/// +/// Factory for creating instances from spans. +/// +[ExcludeFromCodeCoverage] +internal static class EquatableArray +{ + /// + /// Creates an from a . + /// + /// The element type. + /// The span of items to copy. + /// A new containing the items. + public static EquatableArray Create(ReadOnlySpan items) + where T : IEquatable + => new(items.ToArray()); +} + +/// +/// An immutable array wrapper that implements value-based equality, suitable for use +/// in incremental generator models where structural comparison is required. +/// +/// The element type, which must implement . +[ExcludeFromCodeCoverage] +[CollectionBuilder(typeof(EquatableArray), nameof(EquatableArray.Create))] +public readonly struct EquatableArray : IEquatable>, IEnumerable + where T : IEquatable +{ + /// + /// An empty instance. + /// + public static readonly EquatableArray Empty = new(); + + + /// + /// Initializes a new empty . + /// + public EquatableArray() : this([]) { } + + /// + /// Initializes a new from an array. + /// + /// The array to wrap. + public EquatableArray(T[] array) => Array = array ?? []; + + /// + /// Initializes a new from an enumerable. + /// + /// The items to copy into the array. + public EquatableArray(IEnumerable items) => Array = items.ToArray() ?? []; + + + /// + /// Gets the underlying array. + /// + public T[] Array { get; } + + /// + /// Gets the number of elements in the array. + /// + public int Count => Array.Length; + + + /// + /// Returns the array as a . + /// + /// A read-only span over the array elements. + public ReadOnlySpan AsSpan() => Array.AsSpan(); + + /// + /// Returns the underlying array. + /// + /// The underlying array. + public T[] AsArray() => Array; + + + /// + public static bool operator ==(EquatableArray left, EquatableArray right) => left.Equals(right); + + /// + public static bool operator !=(EquatableArray left, EquatableArray right) => !left.Equals(right); + + /// + public bool Equals(EquatableArray array) => Array.AsSpan().SequenceEqual(array.AsSpan()); + + /// + public override bool Equals(object? obj) => obj is EquatableArray array && Equals(this, array); + + /// + public override int GetHashCode() + { + if (Array is not T[] array) + return 0; + + var hashCode = 16777619; + + for (int i = 0; i < array.Length; i++) + hashCode = unchecked((hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(array[i])); + + return hashCode; + } + + + /// + IEnumerator IEnumerable.GetEnumerator() => (Array as IEnumerable).GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() => Array.GetEnumerator(); + + + /// + /// Implicitly converts an array to an . + /// + /// The array to wrap. + public static implicit operator EquatableArray(T[] array) => new(array); + + /// + /// Implicitly converts a list to an . + /// + /// The list to copy. + public static implicit operator EquatableArray(List items) => new(items); +} diff --git a/src/Injectio.Generators/IndentedStringBuilder.cs b/src/Injectio.Generators/Infrastructure/IndentedStringBuilder.cs similarity index 98% rename from src/Injectio.Generators/IndentedStringBuilder.cs rename to src/Injectio.Generators/Infrastructure/IndentedStringBuilder.cs index 14ab0d9..0f0e66c 100644 --- a/src/Injectio.Generators/IndentedStringBuilder.cs +++ b/src/Injectio.Generators/Infrastructure/IndentedStringBuilder.cs @@ -1,7 +1,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text; -namespace Injectio.Generators; +namespace Injectio.Generators.Infrastructure; /// /// A thin wrapper over that adds indentation to each line built. @@ -352,7 +352,7 @@ private void DoIndent() _indentPending = false; } - private sealed class Indenter : IDisposable + private readonly struct Indenter : IDisposable { private readonly IndentedStringBuilder _stringBuilder; @@ -367,7 +367,7 @@ public void Dispose() => _stringBuilder.DecrementIndent(); } - private sealed class IndentSuspender : IDisposable + private readonly struct IndentSuspender : IDisposable { private readonly IndentedStringBuilder _stringBuilder; private readonly byte _indent; diff --git a/src/Injectio.Generators/IsExternalInit.cs b/src/Injectio.Generators/Infrastructure/IsExternalInit.cs similarity index 66% rename from src/Injectio.Generators/IsExternalInit.cs rename to src/Injectio.Generators/Infrastructure/IsExternalInit.cs index 75cd67b..91eb900 100644 --- a/src/Injectio.Generators/IsExternalInit.cs +++ b/src/Injectio.Generators/Infrastructure/IsExternalInit.cs @@ -1,3 +1,5 @@ +#pragma warning disable IDE0130 // Namespace does not match folder structure + using System.ComponentModel; namespace System.Runtime.CompilerServices; diff --git a/src/Injectio.Generators/Injectio.Generators.csproj b/src/Injectio.Generators/Injectio.Generators.csproj index a686a32..95b64a8 100644 --- a/src/Injectio.Generators/Injectio.Generators.csproj +++ b/src/Injectio.Generators/Injectio.Generators.csproj @@ -3,7 +3,11 @@ netstandard2.0 true - false + true + Injectio + true + false + true cs true @@ -13,4 +17,14 @@ + + + + + + + + + + diff --git a/src/Injectio.Generators/KnownTypes.cs b/src/Injectio.Generators/KnownTypes.cs index 77efad4..38048aa 100644 --- a/src/Injectio.Generators/KnownTypes.cs +++ b/src/Injectio.Generators/KnownTypes.cs @@ -10,19 +10,31 @@ public static class KnownTypes public const string TransientAttributeShortName = "RegisterTransient"; public const string TransientAttributeTypeName = $"{TransientAttributeShortName}Attribute"; public const string TransientAttributeFullName = $"{AbstractionNamespace}.{TransientAttributeTypeName}"; + public const string TransientAttributeFullNameT1 = $"{TransientAttributeFullName}`1"; + public const string TransientAttributeFullNameT2 = $"{TransientAttributeFullName}`2"; public const string SingletonAttributeShortName = "RegisterSingleton"; public const string SingletonAttributeTypeName = $"{SingletonAttributeShortName}Attribute"; public const string SingletonAttributeFullName = $"{AbstractionNamespace}.{SingletonAttributeTypeName}"; + public const string SingletonAttributeFullNameT1 = $"{SingletonAttributeFullName}`1"; + public const string SingletonAttributeFullNameT2 = $"{SingletonAttributeFullName}`2"; public const string ScopedAttributeShortName = "RegisterScoped"; public const string ScopedAttributeTypeName = $"{ScopedAttributeShortName}Attribute"; public const string ScopedAttributeFullName = $"{AbstractionNamespace}.{ScopedAttributeTypeName}"; + public const string ScopedAttributeFullNameT1 = $"{ScopedAttributeFullName}`1"; + public const string ScopedAttributeFullNameT2 = $"{ScopedAttributeFullName}`2"; public const string ModuleAttributeShortName = "RegisterServices"; public const string ModuleAttributeTypeName = $"{ModuleAttributeShortName}Attribute"; public const string ModuleAttributeFullName = $"{AbstractionNamespace}.{ModuleAttributeTypeName}"; + public const string DecoratorAttributeShortName = "RegisterDecorator"; + public const string DecoratorAttributeTypeName = $"{DecoratorAttributeShortName}Attribute"; + public const string DecoratorAttributeFullName = $"{AbstractionNamespace}.{DecoratorAttributeTypeName}"; + public const string DecoratorAttributeFullNameT1 = $"{DecoratorAttributeFullName}`1"; + public const string DecoratorAttributeFullNameT2 = $"{DecoratorAttributeFullName}`2"; + public const string ServiceLifetimeSingletonShortName = "Singleton"; public const string ServiceLifetimeSingletonTypeName = $"ServiceLifetime.{ServiceLifetimeSingletonShortName}"; diff --git a/src/Injectio.Generators/Models/DecoratorRegistration.cs b/src/Injectio.Generators/Models/DecoratorRegistration.cs new file mode 100644 index 0000000..4a65526 --- /dev/null +++ b/src/Injectio.Generators/Models/DecoratorRegistration.cs @@ -0,0 +1,14 @@ +using Injectio.Generators.Infrastructure; + +namespace Injectio.Generators.Models; + +public record DecoratorRegistration( + string DecoratorType, + string ServiceType, + string? ServiceKey, + bool IsAnyKey, + string? Factory, + int Order, + EquatableArray Tags, + bool IsOpenGeneric = false +); diff --git a/src/Injectio.Generators/MethodOptions.cs b/src/Injectio.Generators/Models/MethodOptions.cs similarity index 65% rename from src/Injectio.Generators/MethodOptions.cs rename to src/Injectio.Generators/Models/MethodOptions.cs index dc2aabb..01b3b75 100644 --- a/src/Injectio.Generators/MethodOptions.cs +++ b/src/Injectio.Generators/Models/MethodOptions.cs @@ -1,4 +1,4 @@ -namespace Injectio.Generators; +namespace Injectio.Generators.Models; public record MethodOptions( string? Name, diff --git a/src/Injectio.Generators/ModuleRegistration.cs b/src/Injectio.Generators/Models/ModuleRegistration.cs similarity index 77% rename from src/Injectio.Generators/ModuleRegistration.cs rename to src/Injectio.Generators/Models/ModuleRegistration.cs index b8ea6ce..a3e36b2 100644 --- a/src/Injectio.Generators/ModuleRegistration.cs +++ b/src/Injectio.Generators/Models/ModuleRegistration.cs @@ -1,4 +1,4 @@ -namespace Injectio.Generators; +namespace Injectio.Generators.Models; public record ModuleRegistration( string ClassName, diff --git a/src/Injectio.Generators/Models/RegistrationContext.cs b/src/Injectio.Generators/Models/RegistrationContext.cs new file mode 100644 index 0000000..e4b8d10 --- /dev/null +++ b/src/Injectio.Generators/Models/RegistrationContext.cs @@ -0,0 +1,11 @@ +using Injectio.Generators.Infrastructure; + +namespace Injectio.Generators.Models; + +public record RegistrationContext( + EquatableArray ServiceRegistrations, + EquatableArray ModuleRegistrations, + EquatableArray DecoratorRegistrations, + string? AssemblyName, + MethodOptions? MethodOptions +); diff --git a/src/Injectio.Generators/ServiceRegistration.cs b/src/Injectio.Generators/Models/ServiceRegistration.cs similarity index 80% rename from src/Injectio.Generators/ServiceRegistration.cs rename to src/Injectio.Generators/Models/ServiceRegistration.cs index 7be2c98..5eec59f 100644 --- a/src/Injectio.Generators/ServiceRegistration.cs +++ b/src/Injectio.Generators/Models/ServiceRegistration.cs @@ -1,6 +1,8 @@ using System.Diagnostics.CodeAnalysis; -namespace Injectio.Generators; +using Injectio.Generators.Infrastructure; + +namespace Injectio.Generators.Models; public record ServiceRegistration( string Lifetime, diff --git a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs index dbec02e..a4600ee 100644 --- a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs +++ b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs @@ -20,7 +20,15 @@ public class ServiceRegistrationAnalyzer : DiagnosticAnalyzer DiagnosticDescriptors.FactoryMethodInvalidSignature, DiagnosticDescriptors.ServiceTypeMismatch, DiagnosticDescriptors.AbstractImplementationType, - DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass); + DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass, + DiagnosticDescriptors.DecoratorDoesNotImplementService, + DiagnosticDescriptors.DecoratorMissingServiceType, + DiagnosticDescriptors.DecoratorMissingInnerConstructor, + DiagnosticDescriptors.DecoratorFactoryNotFound, + DiagnosticDescriptors.DecoratorFactoryInvalidSignature, + DiagnosticDescriptors.DecoratorOpenGenericKeyed, + DiagnosticDescriptors.DecoratorTargetNotRegistered + ); public override void Initialize(AnalysisContext context) { @@ -29,6 +37,234 @@ public override void Initialize(AnalysisContext context) context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method); context.RegisterSymbolAction(AnalyzeNamedType, SymbolKind.NamedType); + context.RegisterCompilationStartAction(AnalyzeCompilation); + } + + private static void AnalyzeCompilation(CompilationStartAnalysisContext context) + { + var decorators = new List(); + var registeredServices = new HashSet(StringComparer.Ordinal); + var hasModule = false; + + context.RegisterSymbolAction(symbolContext => + { + if (symbolContext.Symbol is IMethodSymbol methodSymbol) + { + foreach (var attribute in methodSymbol.GetAttributes()) + { + if (SymbolHelpers.IsMethodAttribute(attribute)) + { + lock (registeredServices) hasModule = true; + break; + } + } + return; + } + + if (symbolContext.Symbol is not INamedTypeSymbol classSymbol) + return; + + if (classSymbol.IsStatic) + return; + + var attributes = classSymbol.GetAttributes(); + + foreach (var attribute in attributes) + { + if (SymbolHelpers.IsDecoratorAttribute(attribute)) + { + var info = ExtractDecoratorInfo(classSymbol, attribute); + if (info != null) + lock (decorators) decorators.Add(info); + continue; + } + + if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) + continue; + + CollectRegisteredServiceTypes(classSymbol, attribute, registeredServices); + } + }, SymbolKind.NamedType, SymbolKind.Method); + + context.RegisterCompilationEndAction(endContext => + { + foreach (var decorator in decorators) + { + // INJ0015 — open-generic + keyed combination + if (decorator.IsOpenGeneric && (decorator.HasServiceKey || decorator.IsAnyKey)) + { + endContext.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorOpenGenericKeyed, + decorator.Location, + decorator.DecoratorType, + decorator.ServiceType)); + } + + // INJ0016 — target not registered (only when no module could register it dynamically) + if (!hasModule && !decorator.HasServiceKey && !decorator.IsAnyKey) + { + if (!registeredServices.Contains(decorator.ServiceType)) + { + endContext.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorTargetNotRegistered, + decorator.Location, + decorator.DecoratorType, + decorator.ServiceType)); + } + } + } + }); + } + + private sealed class DecoratorInfo + { + public string DecoratorType { get; set; } = string.Empty; + public string ServiceType { get; set; } = string.Empty; + public bool IsOpenGeneric { get; set; } + public bool HasServiceKey { get; set; } + public bool IsAnyKey { get; set; } + public Location Location { get; set; } = Location.None; + } + + private static DecoratorInfo? ExtractDecoratorInfo(INamedTypeSymbol classSymbol, AttributeData attribute) + { + string? serviceType = null; + bool hasServiceKey = false; + bool isAnyKey = false; + bool isOpenGeneric = false; + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length >= 1) + { + var typeArgument = attributeClass.TypeArguments[0]; + if (typeArgument is INamedTypeSymbol namedService && namedService.IsGenericType && namedService.IsUnboundGenericType) + isOpenGeneric = true; + + serviceType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol svc) + { + if (svc.IsGenericType && svc.IsUnboundGenericType) + isOpenGeneric = true; + serviceType = svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + break; + case "ServiceKey": + hasServiceKey = value is not null; + break; + case "AnyKey": + if (value is bool b) + isAnyKey = b; + break; + } + } + + if (serviceType is null) + return null; + + var location = classSymbol.Locations.Length > 0 ? classSymbol.Locations[0] : Location.None; + + return new DecoratorInfo + { + DecoratorType = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat), + ServiceType = serviceType, + IsOpenGeneric = isOpenGeneric, + HasServiceKey = hasServiceKey, + IsAnyKey = isAnyKey, + Location = location, + }; + } + + private static void CollectRegisteredServiceTypes( + INamedTypeSymbol classSymbol, + AttributeData attribute, + HashSet registeredServices) + { + string? implementationType = null; + string? registrationStrategy = null; + var localServiceTypes = new List(); + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length) + { + for (var index = 0; index < attributeClass.TypeParameters.Length; index++) + { + var typeParameter = attributeClass.TypeParameters[index]; + var typeArgument = attributeClass.TypeArguments[index]; + + if (typeParameter.Name == "TService" || index == 0) + localServiceTypes.Add(typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + else if (typeParameter.Name == "TImplementation" || index == 1) + implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + if (value is null) continue; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol svc) + localServiceTypes.Add(svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + break; + case "ImplementationType": + if (value is INamedTypeSymbol impl) + implementationType = impl.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + break; + case "Registration": + registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value); + break; + } + } + + implementationType ??= classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + + if (registrationStrategy == null && localServiceTypes.Count == 0) + registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + + bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + or null; + + if (includeInterfaces) + { + foreach (var iface in classSymbol.AllInterfaces) + { + if (iface.ConstructedFrom.ToString() == "System.IEquatable") + continue; + + var unbound = SymbolHelpers.ToUnboundGenericType(iface); + localServiceTypes.Add(unbound.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + } + } + + bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + or null; + + if (includeSelf || localServiceTypes.Count == 0) + localServiceTypes.Add(implementationType); + + lock (registeredServices) + { + foreach (var t in localServiceTypes) + registeredServices.Add(t); + } } private static void AnalyzeMethod(SymbolAnalysisContext context) @@ -126,17 +362,249 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context) foreach (var attribute in attributes) { - if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) - continue; - var location = classSymbol.Locations.Length > 0 ? classSymbol.Locations[0] : Location.None; + if (SymbolHelpers.IsDecoratorAttribute(attribute)) + { + AnalyzeDecoratorAttribute(context, classSymbol, attribute, location); + continue; + } + + if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) + continue; + AnalyzeRegistrationAttribute(context, classSymbol, attribute, location); } } + private static void AnalyzeDecoratorAttribute( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + AttributeData attribute, + Location location) + { + string? serviceTypeName = null; + INamedTypeSymbol? serviceTypeSymbol = null; + string? factory = null; + bool hasServiceKey = false; + bool isAnyKey = false; + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length >= 1) + { + if (attributeClass.TypeArguments[0] is INamedTypeSymbol serviceArg) + { + serviceTypeSymbol = serviceArg; + serviceTypeName = serviceArg.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol svc) + { + serviceTypeSymbol = svc; + serviceTypeName = svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + break; + case "Factory": + factory = value?.ToString(); + break; + case "ServiceKey": + hasServiceKey = value is not null; + break; + case "AnyKey": + if (value is bool b) + isAnyKey = b; + break; + } + } + + // INJ0011 — missing service type + if (serviceTypeName is null) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorMissingServiceType, + location, + classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat))); + return; + } + + // INJ0010 — class does not implement service + var classTypeName = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (serviceTypeName != classTypeName) + { + var implementsService = false; + + foreach (var iface in classSymbol.AllInterfaces) + { + var ifaceName = iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (ifaceName == serviceTypeName) + { + implementsService = true; + break; + } + + var unboundIface = SymbolHelpers.ToUnboundGenericType(iface); + if (!SymbolEqualityComparer.Default.Equals(unboundIface, iface)) + { + var unboundName = unboundIface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundName == serviceTypeName) + { + implementsService = true; + break; + } + } + } + + if (!implementsService) + { + var baseType = classSymbol.BaseType; + while (baseType is not null) + { + var baseName = baseType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (baseName == serviceTypeName) { implementsService = true; break; } + + var unboundBase = SymbolHelpers.ToUnboundGenericType(baseType); + if (!SymbolEqualityComparer.Default.Equals(unboundBase, baseType)) + { + var unboundBaseName = unboundBase.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundBaseName == serviceTypeName) { implementsService = true; break; } + } + + baseType = baseType.BaseType; + } + } + + if (!implementsService) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorDoesNotImplementService, + location, + classTypeName, + serviceTypeName)); + } + } + + // INJ0012/13/14 — constructor or factory validation + if (factory.HasValue()) + { + ValidateDecoratorFactory(context, classSymbol, factory!, hasServiceKey || isAnyKey, location); + } + else + { + var hasCompatibleCtor = false; + foreach (var ctor in classSymbol.InstanceConstructors) + { + if (ctor.DeclaredAccessibility == Accessibility.Private) continue; + if (ctor.Parameters.Length == 0) continue; + + var firstParamType = ctor.Parameters[0].Type.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (firstParamType == serviceTypeName) + { + hasCompatibleCtor = true; + break; + } + + if (ctor.Parameters[0].Type is INamedTypeSymbol paramNamed) + { + var unboundParam = SymbolHelpers.ToUnboundGenericType(paramNamed); + if (!SymbolEqualityComparer.Default.Equals(unboundParam, paramNamed)) + { + var unboundName = unboundParam.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundName == serviceTypeName) + { + hasCompatibleCtor = true; + break; + } + } + } + + // any parameter matches? + foreach (var param in ctor.Parameters) + { + var paramType = param.Type.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (paramType == serviceTypeName) + { + hasCompatibleCtor = true; + break; + } + } + + if (hasCompatibleCtor) break; + } + + if (!hasCompatibleCtor && classSymbol.InstanceConstructors.Length > 0) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorMissingInnerConstructor, + location, + classTypeName, + serviceTypeName)); + } + } + } + + private static void ValidateDecoratorFactory( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + string factoryMethodName, + bool isKeyed, + Location location) + { + var className = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + var members = classSymbol.GetMembers(factoryMethodName); + var factoryMethods = new List(); + + foreach (var member in members) + { + if (member is IMethodSymbol method) + factoryMethods.Add(method); + } + + if (factoryMethods.Count == 0) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorFactoryNotFound, + location, + factoryMethodName, + className)); + return; + } + + var expectedParamCount = isKeyed ? 3 : 2; + + foreach (var method in factoryMethods) + { + if (!method.IsStatic) continue; + if (method.Parameters.Length != expectedParamCount) continue; + + if (!SymbolHelpers.IsServiceProvider(method.Parameters[0])) continue; + + if (isKeyed) + { + if (method.Parameters[1].Type.SpecialType != SpecialType.System_Object) continue; + // parameter[2] is the inner service — not strictly checked + } + + return; // valid overload found + } + + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorFactoryInvalidSignature, + location, + factoryMethodName, + className)); + } + private static void AnalyzeRegistrationAttribute( SymbolAnalysisContext context, INamedTypeSymbol classSymbol, diff --git a/src/Injectio.Generators/ServiceRegistrationContext.cs b/src/Injectio.Generators/ServiceRegistrationContext.cs deleted file mode 100644 index 2451cce..0000000 --- a/src/Injectio.Generators/ServiceRegistrationContext.cs +++ /dev/null @@ -1,10 +0,0 @@ -using System.Diagnostics.CodeAnalysis; - -using Microsoft.CodeAnalysis; - -namespace Injectio.Generators; - -public record ServiceRegistrationContext( - EquatableArray? ServiceRegistrations = null, - EquatableArray? ModuleRegistrations = null -); diff --git a/src/Injectio.Generators/ServiceRegistrationGenerator.cs b/src/Injectio.Generators/ServiceRegistrationGenerator.cs index d4be12c..717d348 100644 --- a/src/Injectio.Generators/ServiceRegistrationGenerator.cs +++ b/src/Injectio.Generators/ServiceRegistrationGenerator.cs @@ -1,8 +1,11 @@ using System.Collections.Immutable; +using System.IO; using System.Text; using System.Text.RegularExpressions; using Injectio.Generators.Extensions; +using Injectio.Generators.Infrastructure; +using Injectio.Generators.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -16,18 +19,111 @@ public class ServiceRegistrationGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - // find all classes and methods with attributes - var registrations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: SyntacticPredicate, - transform: SemanticTransform - ) - .Where(static context => - context is not null - && (context.ServiceRegistrations?.Count > 0 || context.ModuleRegistrations?.Count > 0) - ) - .Collect() - .WithTrackingName("Registrations"); + context.RegisterPostInitializationOutput(static ctx => + { + ctx.AddEmbeddedAttributeDefinition(); + + ctx.AddSource( + hintName: "Injectio.Attributes.g.cs", + sourceText: SourceText.From(LoadEmbedded("Attributes.cs"), Encoding.UTF8) + ); + ctx.AddSource( + hintName: "Injectio.Extensions.g.cs", + sourceText: SourceText.From(LoadEmbedded("Extensions.cs"), Encoding.UTF8) + ); + }); + + // separate pipeline per attribute, each returning its specific model + var transientRegistrations = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.TransientAttributeFullName, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeTransientFullName)) + .WithTrackingName("TransientRegistrations"); + + var transientRegistrationsT1 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.TransientAttributeFullNameT1, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeTransientFullName)) + .WithTrackingName("TransientRegistrationsT1"); + + var transientRegistrationsT2 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.TransientAttributeFullNameT2, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeTransientFullName)) + .WithTrackingName("TransientRegistrationsT2"); + + var scopedRegistrations = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.ScopedAttributeFullName, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeScopedFullName)) + .WithTrackingName("ScopedRegistrations"); + + var scopedRegistrationsT1 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.ScopedAttributeFullNameT1, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeScopedFullName)) + .WithTrackingName("ScopedRegistrationsT1"); + + var scopedRegistrationsT2 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.ScopedAttributeFullNameT2, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeScopedFullName)) + .WithTrackingName("ScopedRegistrationsT2"); + + var singletonRegistrations = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.SingletonAttributeFullName, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeSingletonFullName)) + .WithTrackingName("SingletonRegistrations"); + + var singletonRegistrationsT1 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.SingletonAttributeFullNameT1, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeSingletonFullName)) + .WithTrackingName("SingletonRegistrationsT1"); + + var singletonRegistrationsT2 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.SingletonAttributeFullNameT2, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeSingletonFullName)) + .WithTrackingName("SingletonRegistrationsT2"); + + var decoratorRegistrations = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.DecoratorAttributeFullName, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformDecoratorRegistration(ctx)) + .WithTrackingName("DecoratorRegistrations"); + + var decoratorRegistrationsT1 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.DecoratorAttributeFullNameT1, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformDecoratorRegistration(ctx)) + .WithTrackingName("DecoratorRegistrationsT1"); + + var decoratorRegistrationsT2 = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.DecoratorAttributeFullNameT2, + predicate: static (node, _) => IsNonAbstractNonStaticType(node), + transform: static (ctx, _) => TransformDecoratorRegistration(ctx)) + .WithTrackingName("DecoratorRegistrationsT2"); + + var moduleRegistrations = context.SyntaxProvider + .ForAttributeWithMetadataName( + KnownTypes.ModuleAttributeFullName, + predicate: static (node, _) => node is MethodDeclarationSyntax, + transform: static (ctx, _) => TransformModuleRegistration(ctx)) + .WithTrackingName("ModuleRegistrations"); // include compilation options var assemblyName = context.CompilationProvider @@ -35,7 +131,7 @@ context is not null .WithTrackingName("AssemblyName"); // include config options - var methodName = context.AnalyzerConfigOptionsProvider + var methodOptions = context.AnalyzerConfigOptionsProvider .Select(static (c, _) => { c.GlobalOptions.TryGetValue("build_property.InjectioName", out var methodName); @@ -44,38 +140,98 @@ context is not null }) .WithTrackingName("Options"); - var options = assemblyName.Combine(methodName); - var generation = registrations.Combine(options); + // combine all service registration pipelines + var allServiceRegistrations = transientRegistrations + .Collect() + .Combine(transientRegistrationsT1.Collect()) + .Combine(transientRegistrationsT2.Collect()) + .Combine(scopedRegistrations.Collect()) + .Combine(scopedRegistrationsT1.Collect()) + .Combine(scopedRegistrationsT2.Collect()) + .Combine(singletonRegistrations.Collect()) + .Combine(singletonRegistrationsT1.Collect()) + .Combine(singletonRegistrationsT2.Collect()) + .Select(static (combined, _) => + { + var ((((((((t, t1), t2), s), s1), s2), si), si1), si2) = combined; + var result = new List(); + foreach (var arr in t) result.AddRange(arr); + foreach (var arr in t1) result.AddRange(arr); + foreach (var arr in t2) result.AddRange(arr); + foreach (var arr in s) result.AddRange(arr); + foreach (var arr in s1) result.AddRange(arr); + foreach (var arr in s2) result.AddRange(arr); + foreach (var arr in si) result.AddRange(arr); + foreach (var arr in si1) result.AddRange(arr); + foreach (var arr in si2) result.AddRange(arr); + return new EquatableArray(result); + }) + .WithTrackingName("AllServiceRegistrations"); + + // combine all decorator registration pipelines + var allDecoratorRegistrations = decoratorRegistrations + .Collect() + .Combine(decoratorRegistrationsT1.Collect()) + .Combine(decoratorRegistrationsT2.Collect()) + .Select(static (combined, _) => + { + var ((d, d1), d2) = combined; + var result = new List(); + foreach (var arr in d) result.AddRange(arr.AsArray()); + foreach (var arr in d1) result.AddRange(arr.AsArray()); + foreach (var arr in d2) result.AddRange(arr.AsArray()); + return new EquatableArray(result); + }) + .WithTrackingName("AllDecoratorRegistrations"); + + // combine all pipelines into a single generation context + var generation = allServiceRegistrations + .Combine(moduleRegistrations.Collect()) + .Combine(allDecoratorRegistrations) + .Combine(assemblyName) + .Combine(methodOptions) + .Select(static (combined, _) => + { + var ((((serviceRegs, moduleRegs), decoratorRegs), asmName), options) = combined; + + return new RegistrationContext( + ServiceRegistrations: serviceRegs, + ModuleRegistrations: new EquatableArray(moduleRegs.Where(static m => m is not null).Cast().ToList()), + DecoratorRegistrations: decoratorRegs, + AssemblyName: asmName, + MethodOptions: options + ); + }) + .WithTrackingName("Generation"); context.RegisterSourceOutput(generation, ExecuteGeneration); } - private void ExecuteGeneration( - SourceProductionContext sourceContext, - (ImmutableArray Registrations, (string? AssemblyName, MethodOptions? MethodOptions) Options) source) + private void ExecuteGeneration(SourceProductionContext sourceContext, RegistrationContext source) { - var serviceRegistrations = source.Registrations - .SelectMany(m => m?.ServiceRegistrations ?? Array.Empty()) - .Where(m => m is not null) - .ToArray(); + var serviceRegistrations = source.ServiceRegistrations.AsArray(); - var moduleRegistrations = source.Registrations - .SelectMany(m => m?.ModuleRegistrations ?? Array.Empty()) - .Where(m => m is not null) + var moduleRegistrations = source.ModuleRegistrations.AsArray(); + + var decoratorRegistrations = source.DecoratorRegistrations + .OrderBy(m => m.ServiceType, StringComparer.Ordinal) + .ThenBy(m => m.Order) + .ThenBy(m => m.DecoratorType, StringComparer.Ordinal) .ToArray(); // compute extension method name - var methodName = source.Options.MethodOptions?.Name; + var methodName = source.MethodOptions?.Name; if (methodName.IsNullOrWhiteSpace()) - methodName = Regex.Replace(source.Options.AssemblyName, "\\W", ""); + methodName = Regex.Replace(source.AssemblyName ?? string.Empty, "\\W", ""); - var methodInternal = source.Options.MethodOptions?.Internal; + var methodInternal = source.MethodOptions?.Internal; // generate registration method var result = ServiceRegistrationWriter.GenerateExtensionClass( moduleRegistrations, serviceRegistrations, - source.Options.AssemblyName, + decoratorRegistrations, + source.AssemblyName, methodName, methodInternal); @@ -83,91 +239,184 @@ private void ExecuteGeneration( sourceContext.AddSource("Injectio.g.cs", SourceText.From(result, Encoding.UTF8)); } - private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken cancellationToken) + private static string LoadEmbedded(string name) { - return syntaxNode switch - { - ClassDeclarationSyntax { AttributeLists.Count: > 0 } declaration => - !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword) - && !declaration.Modifiers.Any(SyntaxKind.StaticKeyword), - - RecordDeclarationSyntax { AttributeLists.Count: > 0 } declaration => - !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword) - && !declaration.Modifiers.Any(SyntaxKind.StaticKeyword), - - MemberDeclarationSyntax { AttributeLists.Count: > 0 } declaration => - !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword), + var assembly = typeof(ServiceRegistrationGenerator).Assembly; + var resourceName = "Injectio.Generators.Embedded." + name; + using var stream = assembly.GetManifestResourceStream(resourceName) + ?? throw new InvalidOperationException($"Embedded resource not found: {resourceName}"); + using var reader = new StreamReader(stream); + return reader.ReadToEnd(); + } - _ => false, - }; + private static bool IsNonAbstractNonStaticType(SyntaxNode node) + { + return node is TypeDeclarationSyntax declaration + and (ClassDeclarationSyntax or RecordDeclarationSyntax) + && !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword) + && !declaration.Modifiers.Any(SyntaxKind.StaticKeyword); } - private static ServiceRegistrationContext? SemanticTransform(GeneratorSyntaxContext context, CancellationToken cancellationToken) + private static EquatableArray TransformServiceRegistration(GeneratorAttributeSyntaxContext context, string serviceLifetime) { - return context.Node switch + if (context.TargetSymbol is not INamedTypeSymbol classSymbol) + return EquatableArray.Empty; + + // find matching attributes for this specific lifetime + var results = new List(); + foreach (var attribute in context.Attributes) { - ClassDeclarationSyntax => SemanticTransformClass(context), - RecordDeclarationSyntax => SemanticTransformClass(context), - MethodDeclarationSyntax => SemanticTransformMethod(context), - _ => null - }; + var registration = CreateServiceRegistration(classSymbol, attribute, serviceLifetime); + if (registration is not null) + results.Add(registration); + } + + return new EquatableArray(results); } - private static ServiceRegistrationContext? SemanticTransformMethod(GeneratorSyntaxContext context) + private static EquatableArray TransformDecoratorRegistration(GeneratorAttributeSyntaxContext context) { - if (context.Node is not MethodDeclarationSyntax methodDeclaration) - return null; + if (context.TargetSymbol is not INamedTypeSymbol classSymbol) + return EquatableArray.Empty; - var methodSymbol = context.SemanticModel.GetDeclaredSymbol(methodDeclaration); - if (methodSymbol is null) - return null; + var results = new List(); + foreach (var attribute in context.Attributes) + { + var decorator = CreateDecoratorRegistration(classSymbol, attribute); + if (decorator is not null) + results.Add(decorator); + } - // make sure attribute is for registration - var attributes = methodSymbol.GetAttributes(); - var isKnown = attributes.Any(SymbolHelpers.IsMethodAttribute); - if (!isKnown) + return new EquatableArray(results); + } + + private static ModuleRegistration? TransformModuleRegistration(GeneratorAttributeSyntaxContext context) + { + if (context.TargetSymbol is not IMethodSymbol methodSymbol) return null; var (isValid, hasTagCollection) = ValidateMethod(methodSymbol); if (!isValid) return null; - var registration = new ModuleRegistration - ( + return new ModuleRegistration( ClassName: methodSymbol.ContainingType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat), MethodName: methodSymbol.Name, IsStatic: methodSymbol.IsStatic, HasTagCollection: hasTagCollection ); - - return new ServiceRegistrationContext(ModuleRegistrations: new[] { registration }); } - private static ServiceRegistrationContext? SemanticTransformClass(GeneratorSyntaxContext context) + private static DecoratorRegistration? CreateDecoratorRegistration(INamedTypeSymbol classSymbol, AttributeData attribute) { - if (context.Node is not (TypeDeclarationSyntax declaration and (ClassDeclarationSyntax or RecordDeclarationSyntax))) - return null; + string? serviceType = null; + string? implementationType = null; + string? serviceKey = null; + bool isAnyKey = false; + string? factory = null; + int order = 0; + var tags = new HashSet(); + bool isOpenGeneric = false; - var classSymbol = context.SemanticModel.GetDeclaredSymbol(declaration); - if (classSymbol is null) - return null; + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length) + { + for (var index = 0; index < attributeClass.TypeParameters.Length; index++) + { + var typeParameter = attributeClass.TypeParameters[index]; + var typeArgument = attributeClass.TypeArguments[index]; - var attributes = classSymbol.GetAttributes(); + if (typeParameter.Name == "TService" || index == 0) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(typeArgument as INamedTypeSymbol); + serviceType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + else if (typeParameter.Name == "TImplementation" || index == 1) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(typeArgument as INamedTypeSymbol); + implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + if (string.IsNullOrEmpty(name)) + continue; - // support multiple register attributes on a class - var registrations = new List(); + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol serviceTypeSymbol) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(serviceTypeSymbol); + serviceType = serviceTypeSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + else if (value != null) + { + serviceType = value.ToString(); + } + break; + case "ImplementationType": + if (value is INamedTypeSymbol implSymbol) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(implSymbol); + implementationType = implSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + else if (value != null) + { + implementationType = value.ToString(); + } + break; + case "ServiceKey": + serviceKey = parameter.Value.ToCSharpString(); + break; + case "AnyKey": + if (value is bool anyKey) + isAnyKey = anyKey; + break; + case "Factory": + factory = value?.ToString(); + break; + case "Order": + if (value is int orderValue) + order = orderValue; + break; + case "Tags": + if (value is string tagsText) + { + foreach (var tag in tagsText.Split(',', ';')) + { + if (tag.HasValue()) + tags.Add(tag.Trim()); + } + } + break; + } + } - foreach (var attribute in attributes) + if (implementationType.IsNullOrWhiteSpace()) { - var registration = CreateServiceRegistration(classSymbol, attribute); - if (registration is not null) - registrations.Add(registration); + var unboundType = SymbolHelpers.ToUnboundGenericType(classSymbol); + isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundType); + implementationType = unboundType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); } - if (registrations.Count == 0) + if (serviceType.IsNullOrWhiteSpace()) return null; - return new ServiceRegistrationContext(ServiceRegistrations: registrations.ToArray()); + return new DecoratorRegistration( + DecoratorType: implementationType!, + ServiceType: serviceType!, + ServiceKey: serviceKey, + IsAnyKey: isAnyKey, + Factory: factory, + Order: order, + Tags: tags.ToArray(), + IsOpenGeneric: isOpenGeneric); } private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbol methodSymbol) @@ -199,12 +448,8 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo return (false, false); } - private static ServiceRegistration? CreateServiceRegistration(INamedTypeSymbol classSymbol, AttributeData attribute) + private static ServiceRegistration? CreateServiceRegistration(INamedTypeSymbol classSymbol, AttributeData attribute, string serviceLifetime) { - // check for known attribute - if (!SymbolHelpers.IsKnownAttribute(attribute, out var serviceLifetime)) - return null; - // defaults var serviceTypes = new HashSet(); string implementationType = null!; @@ -308,6 +553,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName or KnownTypes.RegistrationStrategySelfWithInterfacesShortName or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + if (includeInterfaces) { foreach (var implementedInterface in classSymbol.AllInterfaces) @@ -328,6 +574,7 @@ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName or KnownTypes.RegistrationStrategySelfWithInterfacesShortName or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + if (includeSelf || serviceTypes.Count == 0) serviceTypes.Add(implementationType!); diff --git a/src/Injectio.Generators/ServiceRegistrationWriter.cs b/src/Injectio.Generators/ServiceRegistrationWriter.cs index a393807..26f9aa6 100644 --- a/src/Injectio.Generators/ServiceRegistrationWriter.cs +++ b/src/Injectio.Generators/ServiceRegistrationWriter.cs @@ -1,4 +1,6 @@ using Injectio.Generators.Extensions; +using Injectio.Generators.Infrastructure; +using Injectio.Generators.Models; namespace Injectio.Generators; @@ -11,6 +13,15 @@ public static string GenerateExtensionClass( string? assemblyName, string? methodName, string? methodInternal) + => GenerateExtensionClass(moduleRegistrations, serviceRegistrations, Array.Empty(), assemblyName, methodName, methodInternal); + + public static string GenerateExtensionClass( + IReadOnlyList moduleRegistrations, + IReadOnlyList serviceRegistrations, + IReadOnlyList decoratorRegistrations, + string? assemblyName, + string? methodName, + string? methodInternal) { var codeBuilder = new IndentedStringBuilder(); codeBuilder @@ -66,6 +77,11 @@ public static string GenerateExtensionClass( WriteRegistration(codeBuilder, serviceRegistration); } + foreach (var decoratorRegistration in decoratorRegistrations) + { + WriteDecorator(codeBuilder, decoratorRegistration); + } + codeBuilder .AppendLine("return serviceCollection;") .DecrementIndent() @@ -320,6 +336,165 @@ private static void WriteServiceGeneric( .AppendLine(); } + private static void WriteDecorator( + IndentedStringBuilder codeBuilder, + DecoratorRegistration decorator) + { + if (decorator.Tags.Count > 0) + { + codeBuilder + .Append("if (tagSet.Count == 0 || tagSet.Intersect(new[] { "); + + bool wroteTag = false; + foreach (var tag in decorator.Tags) + { + if (wroteTag) + codeBuilder.Append(", "); + + codeBuilder + .Append("\"") + .Append(tag) + .Append("\""); + + wroteTag = true; + } + + codeBuilder + .AppendLine(" }).Any())") + .AppendLine("{") + .IncrementIndent(); + } + + var serviceType = decorator.ServiceType; + var decoratorType = decorator.DecoratorType; + bool hasServiceKey = decorator.ServiceKey.HasValue(); + bool isKeyed = hasServiceKey || decorator.IsAnyKey; + + // resolve the service key expression passed to the helper + string keyExpression; + if (decorator.IsAnyKey) + keyExpression = "global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey"; + else if (hasServiceKey) + keyExpression = decorator.ServiceKey!; + else + keyExpression = "null"; + + if (decorator.IsOpenGeneric) + { + codeBuilder + .Append("global::Injectio.Extensions.DecorationExtensions.DecorateOpenGeneric(") + .AppendLine() + .IncrementIndent() + .AppendLine("serviceCollection,") + .Append("typeof(") + .AppendIf("global::", !serviceType.StartsWith("global::")) + .Append(serviceType) + .AppendLine("),") + .Append("typeof(") + .AppendIf("global::", !decoratorType.StartsWith("global::")) + .Append(decoratorType) + .AppendLine(")") + .DecrementIndent() + .AppendLine(");") + .AppendLine(); + } + else if (isKeyed) + { + codeBuilder + .Append("global::Injectio.Extensions.DecorationExtensions.DecorateKeyed<") + .AppendIf("global::", !serviceType.StartsWith("global::")) + .Append(serviceType) + .AppendLine(">(") + .IncrementIndent() + .AppendLine("serviceCollection,") + .Append(keyExpression) + .AppendLine(","); + + WriteDecoratorFactory(codeBuilder, decorator, isKeyed: true); + + codeBuilder + .AppendLine() + .DecrementIndent() + .AppendLine(");") + .AppendLine(); + } + else + { + codeBuilder + .Append("global::Injectio.Extensions.DecorationExtensions.Decorate<") + .AppendIf("global::", !serviceType.StartsWith("global::")) + .Append(serviceType) + .AppendLine(">(") + .IncrementIndent() + .AppendLine("serviceCollection,"); + + WriteDecoratorFactory(codeBuilder, decorator, isKeyed: false); + + codeBuilder + .AppendLine() + .DecrementIndent() + .AppendLine(");") + .AppendLine(); + } + + if (decorator.Tags.Count > 0) + { + codeBuilder + .DecrementIndent() + .AppendLine("}") + .AppendLine(); + } + } + + private static void WriteDecoratorFactory( + IndentedStringBuilder codeBuilder, + DecoratorRegistration decorator, + bool isKeyed) + { + var serviceType = decorator.ServiceType; + var decoratorType = decorator.DecoratorType; + var qualifiedService = serviceType.StartsWith("global::") ? serviceType : "global::" + serviceType; + var qualifiedDecorator = decoratorType.StartsWith("global::") ? decoratorType : "global::" + decoratorType; + + if (decorator.Factory.HasValue()) + { + bool hasNamespace = decorator.Factory!.Contains("."); + var factoryTarget = hasNamespace ? decorator.Factory! : qualifiedDecorator + "." + decorator.Factory; + + if (isKeyed) + { + codeBuilder + .Append("static (serviceProvider, serviceKey, inner) => ") + .Append(factoryTarget) + .Append("(serviceProvider, serviceKey, inner)"); + } + else + { + codeBuilder + .Append("static (serviceProvider, inner) => ") + .Append(factoryTarget) + .Append("(serviceProvider, inner)"); + } + } + else + { + if (isKeyed) + { + codeBuilder + .Append("static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<") + .Append(qualifiedDecorator) + .Append(">(serviceProvider, inner)"); + } + else + { + codeBuilder + .Append("static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<") + .Append(qualifiedDecorator) + .Append(">(serviceProvider, inner)"); + } + } + } + public static string GetServiceCollectionMethod(string duplicateStrategy) { return duplicateStrategy switch diff --git a/src/Injectio.Generators/SymbolHelpers.cs b/src/Injectio.Generators/SymbolHelpers.cs index 54cc867..04cfc4d 100644 --- a/src/Injectio.Generators/SymbolHelpers.cs +++ b/src/Injectio.Generators/SymbolHelpers.cs @@ -61,6 +61,19 @@ public static bool IsScopedAttribute(AttributeData attribute) }; } + public static bool IsDecoratorAttribute(AttributeData attribute) + { + return attribute?.AttributeClass is + { + Name: KnownTypes.DecoratorAttributeShortName or KnownTypes.DecoratorAttributeTypeName, + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace.Name: "Injectio" + } + }; + } + public static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime) { if (IsSingletonAttribute(attribute)) diff --git a/src/Injectio.Shared/Attributes.cs b/src/Injectio.Shared/Attributes.cs new file mode 100644 index 0000000..52746b0 --- /dev/null +++ b/src/Injectio.Shared/Attributes.cs @@ -0,0 +1,382 @@ +// +#nullable enable + +namespace Injectio.Attributes +{ + /// + /// Service registration duplicate strategies + /// + [global::Microsoft.CodeAnalysis.Embedded] + internal enum DuplicateStrategy + { + /// + /// Skips registrations for services that already exists. + /// + Skip = 0, + /// + /// Replaces existing service registrations. + /// + Replace = 1, + /// + /// Appends a new registration for existing services. + /// + Append = 2 + } + + /// + /// Register service type strategies + /// + [global::Microsoft.CodeAnalysis.Embedded] + internal enum RegistrationStrategy + { + /// + /// Registers each matching concrete type as itself + /// + Self = 0, + /// + /// Registers each matching concrete type as all of its implemented interfaces. + /// + ImplementedInterfaces = 1, + /// + /// Registers each matching concrete type as all of its implemented interfaces and itself + /// + SelfWithInterfaces = 2, + /// + /// Registers each matching concrete type as all of its implemented interfaces and itself. + /// For the interfaces a proxy-factory resolves the service from its type-name, so only one instance is created per lifetime + /// + /// For open-generic registrations, this behaves like + SelfWithProxyFactory = 3 + } + + /// + /// Attribute to indicate the target class should be register for dependency injection + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal abstract class RegisterAttribute : global::System.Attribute + { + /// + /// The of the service + /// + /// + public global::System.Type? ServiceType { get; set; } + + /// + /// The that implements the service. If not set, the class the interface is on will be used. + /// + /// + public global::System.Type? ImplementationType { get; set; } + + /// + /// Gets or sets the key of the service. + /// + /// The service key. + /// + public object? ServiceKey { get; set; } + + /// + /// Name of a factory method to create new instances of the service implementation + /// + /// + /// The method signature must be Func{IServiceProvider, object} and must be static + /// + /// + public string? Factory { get; set; } + + /// + /// Gets or sets the duplicate. + /// + /// + /// The duplicate. + /// + public DuplicateStrategy Duplicate { get; set; } + + /// + /// Gets or sets the registration. + /// + /// + /// The registration. + /// + public RegistrationStrategy Registration { get; set; } + + /// + /// Gets or sets the comma delimited list of service registration tags. + /// + /// + /// The comma delimited list of service registration tags. + /// + public string? Tags { get; set; } + } + + /// + /// Attribute to indicate the target class should be register for dependency injection as a scoped service + /// + /// Register the class as an implementation for IService + /// + /// [RegisterScoped] + /// public class ScopedService : IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterScopedAttribute : RegisterAttribute; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a singleton service + /// + /// Register the class as an implementation for IService + /// + /// [RegisterSingleton] + /// public class SingletonService : IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterSingletonAttribute : RegisterAttribute; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a transient service + /// + /// Register the class as an implementation for IService + /// + /// [RegisterTransient] + /// public class TransientService : IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterTransientAttribute : RegisterAttribute; + + /// Attribute to indicate the method should be called to register services + /// use the RegisterServices attribute + /// + /// public class RegistrationModule + /// { + /// [RegisterServices] + /// public static void Register(IServiceCollection services) + /// { + /// services.TryAddTransient<IModuleService, ModuleService>(); + /// } + /// } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Method)] + internal class RegisterServicesAttribute : global::System.Attribute; + + /// + /// Attribute to indicate the target class should be registered as a decorator for an existing service. + /// The decorator wraps the previously registered service implementation and inherits its . + /// + /// Decorate IService with a logging wrapper + /// + /// [RegisterDecorator(ServiceType = typeof(IService))] + /// public class LoggingDecorator : IService + /// { + /// public LoggingDecorator(IService inner) { } + /// } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true, Inherited = false)] + internal class RegisterDecoratorAttribute : global::System.Attribute + { + /// + /// The of the service to decorate. + /// + public global::System.Type? ServiceType { get; set; } + + /// + /// The that implements the decorator. If not set, the class the attribute is on will be used. + /// + public global::System.Type? ImplementationType { get; set; } + + /// + /// Gets or sets the key of the keyed service to decorate. + /// Leave unset (and false) to decorate the non-keyed registration. + /// + public object? ServiceKey { get; set; } + + /// + /// When true, the decorator is applied to every keyed registration of , + /// regardless of its key. Equivalent to KeyedService.AnyKey. + /// + public bool AnyKey { get; set; } + + /// + /// Name of a static factory method to construct the decorator. + /// + /// + /// The factory signature must be (IServiceProvider, TService) -> TService for non-keyed services + /// or (IServiceProvider, object?, TService) -> TService for keyed services. + /// + public string? Factory { get; set; } + + /// + /// Gets or sets the order in which the decorator is applied. Lower values are applied first (innermost). + /// + public int Order { get; set; } + + /// + /// Gets or sets the comma delimited list of registration tags. + /// + public string? Tags { get; set; } + } + +#if NET7_0_OR_GREATER + /// + /// Attribute to indicate the target class should be register for dependency injection as a scoped service + /// + /// The type of the service to add. + /// Register the class as an implementation for IService + /// + /// [RegisterScoped<IService>] + /// public class ScopedService : IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterScopedAttribute : RegisterScopedAttribute + where TService : class; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a scoped service + /// + /// The type of the service to add. + /// The type of the implementation to use. + /// Register the ScopedService class as an implementation for IService + /// + /// [RegisterScoped<IService, ScopedService>] + /// public class ScopedService: IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterScopedAttribute : RegisterScopedAttribute + where TService : class + where TImplementation : class, TService; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a singleton service + /// + /// The type of the service to add. + /// Register the class as an implementation for IService + /// + /// [RegisterSingleton<IService>] + /// public class SingletonService : IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterSingletonAttribute : RegisterSingletonAttribute + where TService : class; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a singleton service + /// + /// The type of the service to add. + /// The type of the implementation to use. + /// Register the SingletonService class as an implementation for IService + /// + /// [RegisterSingleton<IService, SingletonService>] + /// public class SingletonService: IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterSingletonAttribute : RegisterSingletonAttribute + where TService : class + where TImplementation : class, TService; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a transient service + /// + /// The type of the service to add. + /// Register the class as an implementation for IService + /// + /// [RegisterTransient<IService>] + /// public class TransientService : IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + internal class RegisterTransientAttribute : RegisterTransientAttribute + where TService : class; + + /// + /// Attribute to indicate the target class should be register for dependency injection as a transient service + /// + /// The type of the service to add. + /// The type of the implementation to use. + /// Register the TransientService class as an implementation for IService + /// + /// [RegisterTransient<IService, TransientService>] + /// public class TransientService: IService { } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)] + internal class RegisterTransientAttribute : RegisterTransientAttribute + where TService : class + where TImplementation : class, TService; + + /// + /// Attribute to indicate the target class should be registered as a decorator for . + /// + /// The type of the service to decorate. + /// + /// + /// [RegisterDecorator<IService>] + /// public class LoggingDecorator : IService + /// { + /// public LoggingDecorator(IService inner) { } + /// } + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true, Inherited = false)] + internal class RegisterDecoratorAttribute : RegisterDecoratorAttribute + where TService : class; + + /// + /// Attribute to indicate the target class should be registered as a decorator for + /// using as the decorator implementation. + /// + /// The type of the service to decorate. + /// The type of the decorator implementation. + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true, Inherited = false)] + internal class RegisterDecoratorAttribute : RegisterDecoratorAttribute + where TService : class + where TImplementation : class, TService; +#endif +} diff --git a/src/Injectio.Shared/Extensions.cs b/src/Injectio.Shared/Extensions.cs new file mode 100644 index 0000000..6554b1d --- /dev/null +++ b/src/Injectio.Shared/Extensions.cs @@ -0,0 +1,451 @@ +// +#nullable enable + +namespace Injectio.Extensions +{ + /// + /// Provides extension methods for decorating registered services in an . + /// + /// + /// + /// The Decorator pattern wraps an existing service with additional behavior without modifying the + /// original implementation. Each decorator implements the same interface as the service it wraps, creating + /// a chain of responsibility: + /// + /// + /// These methods work by replacing the + /// in-place within the service collection. The original factory, instance, or implementation type is captured + /// so the inner service can still be resolved. The replacement descriptor preserves the original lifetime. + /// Calling a decoration method multiple times stacks decorators, with the last call becoming the outermost layer. + /// + /// Usage example — decorating with a factory: + /// + /// // Register the original service + /// services.AddScoped<INotificationService, EmailNotificationService>(); + /// + /// // Decorate with logging + /// services.Decorate<INotificationService>( + /// (sp, inner) => new LoggingNotificationService(inner, sp.GetRequiredService<ILogger>())); + /// + /// // Decorate with retry (stacks on top of logging) + /// services.Decorate<INotificationService>( + /// (sp, inner) => new RetryNotificationService(inner)); + /// + /// // Resolution chain: Retry → Logging → Email + /// + /// Usage example — decorating open generics: + /// + /// // Register closed-generic implementations + /// services.AddScoped<IRepository<Order>, OrderRepository>(); + /// services.AddScoped<IRepository<Customer>, CustomerRepository>(); + /// + /// // Decorate all IRepository<T> with caching + /// services.DecorateOpenGeneric( + /// typeof(IRepository<>), typeof(CachingRepository<>)); + /// + /// + [global::Microsoft.CodeAnalysis.Embedded] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal static class DecorationExtensions + { + /// + /// Decorates all registrations of by wrapping each resolved instance with . + /// Keyed service registrations are skipped. + /// + /// The service type to decorate. + /// The service collection containing the registrations to decorate. + /// A factory that receives the service provider and the original instance, and returns the decorated instance. + /// The same instance for chaining. + /// + /// + /// Decoration replaces each existing for + /// in-place. The original descriptor's factory, instance, or implementation type is + /// captured and used to resolve the inner service, which is then passed to . + /// The replacement descriptor preserves the original registration's lifetime. + /// + /// + /// Multiple calls stack: each call wraps whatever factory is currently registered, enabling layered decoration. + /// Only non-keyed registrations are affected; use DecorateKeyed for keyed services. + /// + /// + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + + if (decoratorFactory is null) + throw new global::System.ArgumentNullException(nameof(decoratorFactory)); + + var serviceType = typeof(TService); + int count = services.Count; + + for (int i = 0; i < count; i++) + { + var descriptor = services[i]; + + // Only rewrite matching non-keyed descriptors for the requested service type. + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var lifetime = descriptor.Lifetime; + + if (descriptor.ImplementationInstance is object instance) + { + // Wrap pre-built singleton/instance registrations without re-creating the inner service. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + factory: sp => decoratorFactory(sp, (TService)instance), + lifetime: lifetime + ); + + continue; + } + + if (descriptor.ImplementationFactory is global::System.Func factory) + { + // Preserve factory semantics by resolving inner through the original factory. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + factory: sp => decoratorFactory(sp, (TService)factory(sp)), + lifetime: lifetime + ); + + continue; + } + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + + // Type registrations are activated first, then passed as the inner instance to the decorator. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + factory: sp => decoratorFactory(sp, (TService)global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType))!, + lifetime: lifetime + ); + } + + return services; + } + + /// + /// Decorates all registrations of by wrapping each resolved instance + /// with an automatically constructed instance of . + /// Keyed service registrations are skipped. + /// + /// The service type to decorate. + /// The decorator type. Its constructor must accept a parameter for the inner service; additional parameters are resolved from the service provider. + /// The service collection containing the registrations to decorate. + /// The same instance for chaining. + /// + /// + /// This overload uses + /// to construct the decorator, passing the inner service instance as an explicit argument. Any remaining + /// constructor parameters are resolved from the . + /// + /// + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services) + where TService : class + where TDecorator : class, TService + { + // Delegate to the factory-based overload so stacking and lifetime behavior stay centralized. + return services.Decorate( + (sp, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, inner)); + } + +#if NET8_0_OR_GREATER + /// + /// Decorates all keyed registrations of matching + /// by wrapping each resolved instance with . + /// Pass to match all keys. + /// + /// The service type to decorate. + /// The service collection containing the registrations to decorate. + /// The key to match, or for all keys. + /// A factory that receives the service provider, the key, and the original instance, and returns the decorated instance. + /// The same instance for chaining. + /// + /// + /// This method iterates through all service descriptors for and replaces each + /// keyed registration whose key matches . The original keyed factory, instance, + /// or implementation type is captured and used to resolve the inner service, which is then passed to + /// along with the service provider and key. The replacement descriptor + /// preserves the original registration's lifetime and key. + /// + /// + /// Non-keyed registrations are skipped; use Decorate for those. Multiple calls stack, enabling + /// layered decoration of keyed services. + /// + /// + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + + if (decoratorFactory is null) + throw new global::System.ArgumentNullException(nameof(decoratorFactory)); + + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + int count = services.Count; + + for (int i = 0; i < count; i++) + { + var descriptor = services[i]; + + // Only rewrite matching keyed descriptors for the requested service type/key. + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var lifetime = descriptor.Lifetime; + + if (descriptor.KeyedImplementationInstance is object keyedInstance) + { + // Wrap pre-built keyed instances directly. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + serviceKey: originalKey, + factory: (sp, key) => decoratorFactory(sp, key, (TService)keyedInstance), + lifetime: lifetime); + + continue; + } + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + { + // Flow the runtime key through keyed factories to preserve original keyed behavior. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + serviceKey: originalKey, + factory: (sp, key) => decoratorFactory(sp, key, (TService)keyedFactory(sp, key)), + lifetime: lifetime); + + continue; + } + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + + // For keyed type registrations, resolve by type and decorate; key remains a descriptor selector. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + serviceKey: originalKey, + factory: (sp, key) => decoratorFactory(sp, key, (TService)global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType)), + lifetime: lifetime); + } + + return services; + } + + /// + /// Decorates all keyed registrations of matching + /// by wrapping each resolved instance with an automatically constructed instance of . + /// Pass to match all keys. + /// + /// The service type to decorate. + /// The decorator type. Its constructor must accept a parameter for the inner service; additional parameters are resolved from the service provider. + /// The service collection containing the registrations to decorate. + /// The key to match, or for all keys. + /// The same instance for chaining. + /// + /// + /// This overload uses + /// to construct the decorator, passing the inner service instance as an explicit argument. Any remaining + /// constructor parameters are resolved from the . + /// + /// + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey) + where TService : class + where TDecorator : class, TService + { + // Delegate to the keyed factory overload so key matching behavior is shared in one place. + return services.DecorateKeyed( + serviceKey, + (sp, key, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, inner)); + } +#endif + + /// + /// Decorates all closed-generic registrations whose generic type definition matches + /// by wrapping each resolved instance with an instance of . + /// + /// The service collection containing the registrations to decorate. + /// The open generic service type (e.g., typeof(IRepository<>)). + /// The open generic decorator type that wraps the original service. + /// The same instance for chaining. + /// + /// + /// This method scans all service descriptors for closed-generic types whose generic type definition matches + /// . For each match, it closes over the + /// same type arguments and replaces the descriptor with a factory that resolves the original service and passes + /// it to . + /// The replacement descriptor preserves the original registration's lifetime. + /// + /// + /// Open-generic descriptors (those registered with an open typeof(T<>) service type) are skipped + /// because the DI container resolves them lazily and replacing the implementation type would cause recursive + /// resolution. Both keyed and non-keyed closed-generic registrations are decorated when running on .NET 8+. + /// + /// + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + + if (openServiceType is null) + throw new global::System.ArgumentNullException(nameof(openServiceType)); + + if (openDecoratorType is null) + throw new global::System.ArgumentNullException(nameof(openDecoratorType)); + + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + int count = services.Count; + for (int i = 0; i < count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + // Only closed generic registrations can be rewritten safely. + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + // Close the open decorator over the current service type arguments. + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + // Skip incompatible generic argument combinations. + continue; + } + +#if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var lifetime = descriptor.Lifetime; + + if (descriptor.KeyedImplementationInstance is object keyedInstance) + { + // For implementation instances, the key selects the descriptor; the instance is already materialized. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + serviceKey: originalKey, + factory: (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, keyedInstance), + lifetime: lifetime); + + continue; + } + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + { + // Factory registrations are key-aware by contract, so the runtime key must flow to the inner factory. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + serviceKey: originalKey, + factory: (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, keyedFactory(sp, key)), + lifetime: lifetime); + + continue; + } + + var keyedImplementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + + // For type-based activation, MS.DI uses the key for descriptor selection, not as an implicit ctor argument. + // Passing key into CreateInstance could bind to arbitrary object parameters and change activation semantics. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + serviceKey: originalKey, + factory: (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, keyedImplementationType)), + lifetime: lifetime); + + continue; + } +#endif + + var nonKeyedLifetime = descriptor.Lifetime; + + if (descriptor.ImplementationInstance is object instance) + { + // Wrap non-keyed implementation instances directly. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + factory: sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, instance), + lifetime: nonKeyedLifetime); + + continue; + } + + if (descriptor.ImplementationFactory is global::System.Func factory) + { + // Preserve factory semantics by resolving inner through the original factory. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + factory: sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, factory(sp)), + lifetime: nonKeyedLifetime); + + continue; + } + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + + // Type registrations are activated and then wrapped by the closed decorator. + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType: serviceType, + factory: sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType)), + lifetime: nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { +#if NET8_0_OR_GREATER + // On .NET 8+, keyed metadata is available directly on the descriptor. + return descriptor.IsKeyedService; +#else + // Keyed services are unavailable on older target frameworks. + return false; +#endif + } + + } +} diff --git a/src/Injectio.Shared/Infrastructure/EmbeddedAttribute.cs b/src/Injectio.Shared/Infrastructure/EmbeddedAttribute.cs new file mode 100644 index 0000000..bdc2526 --- /dev/null +++ b/src/Injectio.Shared/Infrastructure/EmbeddedAttribute.cs @@ -0,0 +1,8 @@ +namespace Microsoft.CodeAnalysis; + +/// +/// A special attribute recognized by Roslyn, that marks a type as "embedded", meaning it won't ever be visible from other assemblies. +/// +[global::System.AttributeUsage(global::System.AttributeTargets.All)] +[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +internal sealed partial class EmbeddedAttribute : global::System.Attribute; diff --git a/src/Injectio.Shared/Injectio.Shared.csproj b/src/Injectio.Shared/Injectio.Shared.csproj new file mode 100644 index 0000000..5ff6295 --- /dev/null +++ b/src/Injectio.Shared/Injectio.Shared.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0;net8.0;net9.0;net10.0 + Injectio + false + disable + + + + + diff --git a/tests/Injectio.Acceptance.Tests/DecoratorTests.cs b/tests/Injectio.Acceptance.Tests/DecoratorTests.cs new file mode 100644 index 0000000..c16818d --- /dev/null +++ b/tests/Injectio.Acceptance.Tests/DecoratorTests.cs @@ -0,0 +1,48 @@ +using AwesomeAssertions; + +using Injectio.Acceptance.Tests.Services; + +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Acceptance.Tests; + +[Collection(DependencyInjectionCollection.CollectionName)] +public class DecoratorTests(DependencyInjectionFixture fixture) : DependencyInjectionBase(fixture) +{ + [Fact] + public void ShouldResolveChainedDecoratorsInOrder() + { + var greeter = Services.GetRequiredService(); + + greeter.Should().BeOfType(); + greeter.Greet().Should().Be("caching(logging(base))"); + + var caching = (CachingGreeter)greeter; + caching.Inner.Should().BeOfType(); + + var logging = (LoggingGreeter)caching.Inner; + logging.Inner.Should().BeOfType(); + } + + [Fact] + public void ShouldDecorateClosedGenericViaOpenDecorator() + { + var repo = Services.GetRequiredService>(); + + repo.Should().BeOfType>(); + repo.Describe().Should().Be("logging(repo)"); + } + + [Fact] + public void ShouldDecorateEveryKeyedVariantWhenAnyKey() + { + var alpha = Services.GetRequiredKeyedService("alpha"); + var beta = Services.GetRequiredKeyedService("beta"); + + alpha.Should().BeOfType(); + alpha.Name.Should().Be("wrapped(alpha)"); + + beta.Should().BeOfType(); + beta.Name.Should().Be("wrapped(beta)"); + } +} diff --git a/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj b/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj index 2bff6fc..af3eee1 100644 --- a/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj +++ b/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj @@ -24,7 +24,6 @@ - diff --git a/tests/Injectio.Acceptance.Tests/Services/DecoratorService.cs b/tests/Injectio.Acceptance.Tests/Services/DecoratorService.cs new file mode 100644 index 0000000..a428490 --- /dev/null +++ b/tests/Injectio.Acceptance.Tests/Services/DecoratorService.cs @@ -0,0 +1,82 @@ +using Injectio.Attributes; + +namespace Injectio.Acceptance.Tests.Services; + +public interface IGreeter +{ + string Greet(); +} + +[RegisterSingleton] +public class BaseGreeter : IGreeter +{ + public string Greet() => "base"; +} + +[RegisterDecorator(Order = 1)] +public class LoggingGreeter : IGreeter +{ + public IGreeter Inner { get; } + + public LoggingGreeter(IGreeter inner) => Inner = inner; + + public string Greet() => $"logging({Inner.Greet()})"; +} + +[RegisterDecorator(Order = 2)] +public class CachingGreeter : IGreeter +{ + public IGreeter Inner { get; } + + public CachingGreeter(IGreeter inner) => Inner = inner; + + public string Greet() => $"caching({Inner.Greet()})"; +} + +public interface IRepo +{ + string Describe(); +} + +[RegisterSingleton, StringRepo>] +public class StringRepo : IRepo +{ + public string Describe() => "repo"; +} + +[RegisterDecorator(ServiceType = typeof(IRepo<>))] +public class LoggingRepo : IRepo +{ + public IRepo Inner { get; } + + public LoggingRepo(IRepo inner) => Inner = inner; + + public string Describe() => $"logging({Inner.Describe()})"; +} + +public interface IKeyedThing +{ + string Name { get; } +} + +[RegisterSingleton(ServiceKey = "alpha")] +public class AlphaThing : IKeyedThing +{ + public string Name => "alpha"; +} + +[RegisterSingleton(ServiceKey = "beta")] +public class BetaThing : IKeyedThing +{ + public string Name => "beta"; +} + +[RegisterDecorator(AnyKey = true)] +public class WrappedThing : IKeyedThing +{ + public IKeyedThing Inner { get; } + + public WrappedThing(IKeyedThing inner) => Inner = inner; + + public string Name => $"wrapped({Inner.Name})"; +} diff --git a/tests/Injectio.Shared.Tests/DecorationExtensionsTests.cs b/tests/Injectio.Shared.Tests/DecorationExtensionsTests.cs new file mode 100644 index 0000000..3cab159 --- /dev/null +++ b/tests/Injectio.Shared.Tests/DecorationExtensionsTests.cs @@ -0,0 +1,374 @@ +using AwesomeAssertions; + +using Injectio.Extensions; + +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Shared.Tests; + +public interface IGreeter +{ + string Greet(); +} + +public class SimpleGreeter : IGreeter +{ + public string Greet() => "Hello"; +} + +public class DecoratedGreeter(IGreeter inner) : IGreeter +{ + public IGreeter Inner => inner; + public string Greet() => $"[Decorated] {inner.Greet()}"; +} + +public class AnotherDecoratedGreeter(IGreeter inner) : IGreeter +{ + public IGreeter Inner => inner; + public string Greet() => $"[Another] {inner.Greet()}"; +} + +public interface IRepository +{ + T? Get(int id); +} + +public class InMemoryRepository : IRepository +{ + public T? Get(int id) => default; +} + +public class CachingRepository(IRepository inner) : IRepository +{ + public IRepository Inner => inner; + public T? Get(int id) => inner.Get(id); +} + +public class DecorationExtensionsTests +{ + [Fact] + public void WhenDecorateWithFactoryThenResolvedServiceIsDecorated() + { + var services = new ServiceCollection(); + services.AddTransient(); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredService(); + + greeter.Should().BeOfType(); + greeter.Greet().Should().Be("[Decorated] Hello"); + } + + [Fact] + public void WhenDecorateWithFactoryThenLifetimeIsPreserved() + { + var services = new ServiceCollection(); + services.AddScoped(); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + services[0].Lifetime.Should().Be(ServiceLifetime.Scoped); + } + + [Fact] + public void WhenDecorateWithFactoryThenReturnsSameServiceCollection() + { + var services = new ServiceCollection(); + services.AddTransient(); + + var result = services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + result.Should().BeSameAs(services); + } + + [Fact] + public void WhenDecorateCalledMultipleTimesThenDecoratorsStack() + { + var services = new ServiceCollection(); + services.AddTransient(); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + services.Decorate((sp, inner) => new AnotherDecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredService(); + + greeter.Should().BeOfType(); + greeter.Greet().Should().Be("[Another] [Decorated] Hello"); + } + + [Fact] + public void WhenDecorateWithImplementationInstanceThenInstanceIsWrapped() + { + var services = new ServiceCollection(); + services.AddSingleton(new SimpleGreeter()); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredService(); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateWithFactoryRegistrationThenFactoryIsWrapped() + { + var services = new ServiceCollection(); + services.AddTransient(sp => new SimpleGreeter()); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredService(); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenNoRegistrationExistsThenDecorateIsNoOp() + { + var services = new ServiceCollection(); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + services.Should().BeEmpty(); + } + + [Fact] + public void WhenDecorateWithTypeParametersThenDecoratorIsResolved() + { + var services = new ServiceCollection(); + services.AddTransient(); + + services.Decorate(); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredService(); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedServiceWithFactoryThenKeyedServiceIsDecorated() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("key1"); + + services.DecorateKeyed("key1", (sp, key, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredKeyedService("key1"); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedServiceThenNonKeyedIsSkipped() + { + var services = new ServiceCollection(); + services.AddTransient(); + services.AddKeyedTransient("key1"); + + services.DecorateKeyed("key1", (sp, key, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + // Non-keyed should remain undecorated + var nonKeyed = provider.GetRequiredService(); + nonKeyed.Should().BeOfType(); + + // Keyed should be decorated + var keyed = provider.GetRequiredKeyedService("key1"); + keyed.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedWithDifferentKeyThenNotDecorated() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("key1"); + + services.DecorateKeyed("key2", (sp, key, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredKeyedService("key1"); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedWithAnyKeyThenAllKeyedAreDecorated() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("key1"); + services.AddKeyedTransient("key2"); + + services.DecorateKeyed( + KeyedService.AnyKey, + (sp, key, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + + provider.GetRequiredKeyedService("key1").Should().BeOfType(); + provider.GetRequiredKeyedService("key2").Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedWithTypeParametersThenDecoratorIsResolved() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("key1"); + + services.DecorateKeyed("key1"); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredKeyedService("key1"); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedThenLifetimeIsPreserved() + { + var services = new ServiceCollection(); + services.AddKeyedScoped("key1"); + + services.DecorateKeyed("key1", (sp, key, inner) => new DecoratedGreeter(inner)); + + services[0].Lifetime.Should().Be(ServiceLifetime.Scoped); + } + + [Fact] + public void WhenDecorateKeyedWithInstanceThenInstanceIsWrapped() + { + var services = new ServiceCollection(); + services.AddKeyedSingleton("key1", new SimpleGreeter()); + + services.DecorateKeyed("key1", (sp, key, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredKeyedService("key1"); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateKeyedWithFactoryRegistrationThenFactoryIsWrapped() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("key1", (sp, key) => new SimpleGreeter()); + + services.DecorateKeyed("key1", (sp, key, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredKeyedService("key1"); + + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateOpenGenericThenClosedGenericsAreDecorated() + { + var services = new ServiceCollection(); + services.AddTransient, InMemoryRepository>(); + services.AddTransient, InMemoryRepository>(); + + services.DecorateOpenGeneric(typeof(IRepository<>), typeof(CachingRepository<>)); + + using var provider = services.BuildServiceProvider(); + + provider.GetRequiredService>().Should().BeOfType>(); + provider.GetRequiredService>().Should().BeOfType>(); + } + + [Fact] + public void WhenDecorateOpenGenericThenLifetimeIsPreserved() + { + var services = new ServiceCollection(); + services.AddScoped, InMemoryRepository>(); + + services.DecorateOpenGeneric(typeof(IRepository<>), typeof(CachingRepository<>)); + + services[0].Lifetime.Should().Be(ServiceLifetime.Scoped); + } + + [Fact] + public void WhenDecorateOpenGenericThenReturnsSameServiceCollection() + { + var services = new ServiceCollection(); + services.AddTransient, InMemoryRepository>(); + + var result = services.DecorateOpenGeneric(typeof(IRepository<>), typeof(CachingRepository<>)); + + result.Should().BeSameAs(services); + } + + [Fact] + public void WhenDecorateOpenGenericThenNonGenericRegistrationsAreSkipped() + { + var services = new ServiceCollection(); + services.AddTransient(); + services.AddTransient, InMemoryRepository>(); + + services.DecorateOpenGeneric(typeof(IRepository<>), typeof(CachingRepository<>)); + + using var provider = services.BuildServiceProvider(); + + provider.GetRequiredService().Should().BeOfType(); + provider.GetRequiredService>().Should().BeOfType>(); + } + + [Fact] + public void WhenDecorateOpenGenericThenOpenGenericDescriptorsAreSkipped() + { + var services = new ServiceCollection(); + services.AddTransient(typeof(IRepository<>), typeof(InMemoryRepository<>)); + + services.DecorateOpenGeneric(typeof(IRepository<>), typeof(CachingRepository<>)); + + // The open-generic descriptor should remain unchanged + services[0].ImplementationType.Should().Be(typeof(InMemoryRepository<>)); + } + + [Fact] + public void WhenDecorateOpenGenericWithKeyedServiceThenKeyedIsDecorated() + { + var services = new ServiceCollection(); + services.AddKeyedTransient, InMemoryRepository>("cache"); + + services.DecorateOpenGeneric(typeof(IRepository<>), typeof(CachingRepository<>)); + + using var provider = services.BuildServiceProvider(); + var repo = provider.GetRequiredKeyedService>("cache"); + + repo.Should().BeOfType>(); + } + + [Fact] + public void WhenDecorateSkipsKeyedRegistrations() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("key1"); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + using var provider = services.BuildServiceProvider(); + var greeter = provider.GetRequiredKeyedService("key1"); + + // Keyed registration should not be affected by non-keyed Decorate + greeter.Should().BeOfType(); + } + + [Fact] + public void WhenDecorateWithInstanceRegistrationThenSingletonLifetimePreserved() + { + var services = new ServiceCollection(); + services.AddSingleton(new SimpleGreeter()); + + services.Decorate((sp, inner) => new DecoratedGreeter(inner)); + + services[0].Lifetime.Should().Be(ServiceLifetime.Singleton); + } +} diff --git a/tests/Injectio.Shared.Tests/Injectio.Shared.Tests.csproj b/tests/Injectio.Shared.Tests/Injectio.Shared.Tests.csproj new file mode 100644 index 0000000..793b71d --- /dev/null +++ b/tests/Injectio.Shared.Tests/Injectio.Shared.Tests.csproj @@ -0,0 +1,29 @@ + + + + net10.0 + Exe + enable + enable + latest + false + true + + + + + + + + + + + + + + + + + + + diff --git a/tests/Injectio.Tests.Console/Injectio.Tests.Console.csproj b/tests/Injectio.Tests.Console/Injectio.Tests.Console.csproj index c3fb328..b354138 100644 --- a/tests/Injectio.Tests.Console/Injectio.Tests.Console.csproj +++ b/tests/Injectio.Tests.Console/Injectio.Tests.Console.csproj @@ -24,7 +24,6 @@ - diff --git a/tests/Injectio.Tests.Library/Injectio.Tests.Library.csproj b/tests/Injectio.Tests.Library/Injectio.Tests.Library.csproj index 04e3105..f105fe5 100644 --- a/tests/Injectio.Tests.Library/Injectio.Tests.Library.csproj +++ b/tests/Injectio.Tests.Library/Injectio.Tests.Library.csproj @@ -16,7 +16,6 @@ - diff --git a/tests/Injectio.Tests/Injectio.Tests.csproj b/tests/Injectio.Tests/Injectio.Tests.csproj index cac5916..46d18ed 100644 --- a/tests/Injectio.Tests/Injectio.Tests.csproj +++ b/tests/Injectio.Tests/Injectio.Tests.csproj @@ -23,7 +23,6 @@ - Analyzer true diff --git a/tests/Injectio.Tests/ServiceRegistrationDecoratorTests.cs b/tests/Injectio.Tests/ServiceRegistrationDecoratorTests.cs new file mode 100644 index 0000000..0e9d21e --- /dev/null +++ b/tests/Injectio.Tests/ServiceRegistrationDecoratorTests.cs @@ -0,0 +1,487 @@ +using System; +using System.Collections.Immutable; +using System.Linq; +using System.Threading.Tasks; + +using AwesomeAssertions; + +using Injectio.Generators; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.Extensions.DependencyInjection; + +using VerifyXunit; + +using Xunit; + +namespace Injectio.Tests; + +public class ServiceRegistrationDecoratorTests +{ + [Fact] + public Task GenerateDecoratorSimple() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(ServiceType = typeof(IService))] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorGenericAttribute() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorChainedOrder() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Order = 2)] + public class CachingDecorator : IService + { + public CachingDecorator(IService inner) { } + } + + [RegisterDecorator(Order = 1)] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorKeyed() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton(ServiceKey = "Alpha")] + public class Service : IService { } + + [RegisterDecorator(ServiceKey = "Alpha")] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorAnyKey() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton(ServiceKey = "Alpha")] + public class ServiceA : IService { } + + [RegisterSingleton(ServiceKey = "Beta")] + public class ServiceB : IService { } + + [RegisterDecorator(AnyKey = true)] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorFactory() + { + const string source = """ + using System; + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Factory = nameof(Create))] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + + public static IService Create(IServiceProvider serviceProvider, IService inner) + => new LoggingDecorator(inner); + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorOpenGeneric() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IRepo { } + + [RegisterSingleton(ServiceType = typeof(IRepo<>), ImplementationType = typeof(Repo<>))] + public class Repo : IRepo { } + + [RegisterDecorator(ServiceType = typeof(IRepo<>))] + public class LoggingRepo : IRepo + { + public LoggingRepo(IRepo inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorTags() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton(Tags = "FrontEnd")] + public class Service : IService { } + + [RegisterDecorator(Tags = "FrontEnd")] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + // ------- Diagnostics ------- + + [Fact] + public async Task DiagnoseDecoratorDoesNotImplementService() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + public interface IOther { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(ServiceType = typeof(IService))] + public class BadDecorator : IOther { } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0010"); + } + + [Fact] + public async Task DiagnoseDecoratorMissingServiceType() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0011"); + } + + [Fact] + public async Task DiagnoseDecoratorFactoryNotFound() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Factory = "Missing")] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0013"); + } + + [Fact] + public async Task DiagnoseDecoratorFactoryInvalidSignature() + { + const string source = """ + using System; + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Factory = nameof(Create))] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + + public IService Create() => this; + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0014"); + } + + [Fact] + public async Task DiagnoseDecoratorOpenGenericKeyed() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IRepo { } + + [RegisterSingleton(ServiceType = typeof(IRepo<>), ImplementationType = typeof(Repo<>))] + public class Repo : IRepo { } + + [RegisterDecorator(ServiceType = typeof(IRepo<>), ServiceKey = "X")] + public class LoggingRepo : IRepo + { + public LoggingRepo(IRepo inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0015"); + } + + [Fact] + public async Task DiagnoseDecoratorTargetNotRegistered() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0016"); + } + + [Fact] + public async Task NoDiagnosticForValidDecorator() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().BeEmpty(); + } + + + private static Task Verify(string source) + { + var output = GetGeneratedOutput(source); + + return Verifier + .Verify(output) + .UseDirectory("Snapshots") + .ScrubLinesContaining("GeneratedCodeAttribute"); + } + + private static string GetGeneratedOutput(string source) + where T : IIncrementalGenerator, new() + { + var parseOptions = CSharpParseOptions.Default.WithPreprocessorSymbols( + "NET7_0_OR_GREATER", + "NET8_0_OR_GREATER", + "NET9_0_OR_GREATER", + "NET10_0_OR_GREATER"); + + var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions); + var references = AppDomain.CurrentDomain.GetAssemblies() + .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) + .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) + .Concat( + [ + MetadataReference.CreateFromFile(typeof(T).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), + ]); + + var compilation = CSharpCompilation.Create( + "Test.Generator", + [syntaxTree], + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var originalTreeCount = compilation.SyntaxTrees.Length; + var generator = new T(); + + var driver = CSharpGeneratorDriver.Create( + generators: [generator.AsSourceGenerator()], + parseOptions: parseOptions); + + driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); + + var generated = outputCompilation.SyntaxTrees + .Skip(originalTreeCount) + .FirstOrDefault(t => Path.GetFileName(t.FilePath) == "Injectio.g.cs"); + + return generated?.ToString() ?? string.Empty; + } + + private static async Task> GetDiagnosticsAsync(string source) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var attributeSource = LoadEmbeddedAttributeSource(); + + var attributeTree = CSharpSyntaxTree.ParseText(attributeSource, + CSharpParseOptions.Default.WithPreprocessorSymbols("NET7_0_OR_GREATER"), + path: "Injectio.Attributes.cs"); + + var references = AppDomain.CurrentDomain.GetAssemblies() + .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) + .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) + .Concat( + [ + MetadataReference.CreateFromFile(typeof(ServiceRegistrationGenerator).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), + ]); + + var compilation = CSharpCompilation.Create( + "Test.Diagnostics", + [syntaxTree, attributeTree], + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var analyzer = new ServiceRegistrationAnalyzer(); + var compilationWithAnalyzers = compilation.WithAnalyzers([analyzer]); + var diagnostics = await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync(); + + return diagnostics + .Where(d => d.Id.StartsWith("INJ")) + .ToImmutableArray(); + } + + private static string LoadEmbeddedAttributeSource() + { + var assembly = typeof(ServiceRegistrationGenerator).Assembly; + using var stream = assembly.GetManifestResourceStream("Injectio.Generators.Embedded.Attributes.cs") + ?? throw new InvalidOperationException("Embedded Attributes.cs resource not found"); + + using var reader = new System.IO.StreamReader(stream); + var source = reader.ReadToEnd(); + + // Strip markers that prevent attributes from resolving in a standalone compilation + return source + .Replace("// ", "") + .Replace("[global::Microsoft.CodeAnalysis.Embedded]", ""); + } +} diff --git a/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs b/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs index b3632d8..d563ca6 100644 --- a/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs +++ b/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs @@ -5,7 +5,6 @@ using AwesomeAssertions; -using Injectio.Attributes; using Injectio.Generators; using Microsoft.CodeAnalysis; @@ -420,24 +419,29 @@ public static void Register(IServiceCollection services, IEnumerable tag private static async Task> GetDiagnosticsAsync(string source) { var syntaxTree = CSharpSyntaxTree.ParseText(source); + var attributeSource = LoadEmbeddedAttributeSource(); + + var attributeTree = CSharpSyntaxTree.ParseText(attributeSource, + CSharpParseOptions.Default.WithPreprocessorSymbols("NET7_0_OR_GREATER"), + path: "Injectio.Attributes.cs"); + var references = AppDomain.CurrentDomain.GetAssemblies() .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) - .Concat(new[] - { + .Concat( + [ MetadataReference.CreateFromFile(typeof(ServiceRegistrationGenerator).Assembly.Location), - MetadataReference.CreateFromFile(typeof(RegisterServicesAttribute).Assembly.Location), MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), - }); + ]); var compilation = CSharpCompilation.Create( "Test.Diagnostics", - new[] { syntaxTree }, + [syntaxTree, attributeTree], references, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); var analyzer = new ServiceRegistrationAnalyzer(); - var compilationWithAnalyzers = compilation.WithAnalyzers(ImmutableArray.Create(analyzer)); + var compilationWithAnalyzers = compilation.WithAnalyzers([analyzer]); var diagnostics = await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync(); // return only Injectio diagnostics @@ -445,4 +449,19 @@ private static async Task> GetDiagnosticsAsync(string .Where(d => d.Id.StartsWith("INJ")) .ToImmutableArray(); } + + private static string LoadEmbeddedAttributeSource() + { + var assembly = typeof(ServiceRegistrationGenerator).Assembly; + using var stream = assembly.GetManifestResourceStream("Injectio.Generators.Embedded.Attributes.cs") + ?? throw new InvalidOperationException("Embedded Attributes.cs resource not found"); + + using var reader = new System.IO.StreamReader(stream); + var source = reader.ReadToEnd(); + + // Strip markers that prevent attributes from resolving in a standalone compilation + return source + .Replace("// ", "") + .Replace("[global::Microsoft.CodeAnalysis.Embedded]", ""); + } } diff --git a/tests/Injectio.Tests/ServiceRegistrationGeneratorTests.cs b/tests/Injectio.Tests/ServiceRegistrationGeneratorTests.cs index cb9bcb8..5d1dff2 100644 --- a/tests/Injectio.Tests/ServiceRegistrationGeneratorTests.cs +++ b/tests/Injectio.Tests/ServiceRegistrationGeneratorTests.cs @@ -5,7 +5,6 @@ using AwesomeAssertions; -using Injectio.Attributes; using Injectio.Generators; using Microsoft.CodeAnalysis; @@ -37,12 +36,7 @@ public class SingletonService : IService """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -62,12 +56,7 @@ public class SingletonService : IService """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -87,12 +76,7 @@ public class SingletonService : IService """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -112,12 +96,7 @@ public class SingletonService : IService """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -145,12 +124,7 @@ public class ServiceMultiple : IService1, IService2 """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -170,12 +144,7 @@ public class SingletonService : IService """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -197,12 +166,7 @@ public class SingletonService : IService1, IService2 """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -210,6 +174,7 @@ public Task GenerateRegisterServicesInstance() { const string source = """ + using System; using Injectio.Attributes; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -235,12 +200,7 @@ public void Register(IServiceCollection services) """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -273,12 +233,7 @@ public static void Register(IServiceCollection services) """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -311,12 +266,7 @@ public static void Register(IServiceCollection services, string test) """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -349,12 +299,7 @@ public static void Register(string test) """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -391,12 +336,7 @@ public static IFactoryService1 ServiceFactory(IServiceProvider serviceProvider) """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -417,12 +357,7 @@ public class OpenGeneric : IOpenGeneric """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -443,12 +378,7 @@ public class OpenGeneric : IOpenGeneric """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -469,12 +399,7 @@ public class Service : IClosedGeneric """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } [Fact] @@ -495,15 +420,9 @@ public class ServiceTag : IServiceTag """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } -#if NET7_0_OR_GREATER [Fact] public Task GenerateRegisterSingletonGeneric() { @@ -521,26 +440,24 @@ public class SingletonService : IService """; - var output = GetGeneratedOutput(source); - - return Verifier - .Verify(output) - .UseDirectory("Snapshots") - .ScrubLinesContaining("GeneratedCodeAttribute"); + return Verify(source); } -#endif [Fact] public Task GenerateRegisterSingletonServiceKeys() { const string source = """ + using System; using Injectio.Attributes; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; namespace Injectio.Sample; + public interface IServiceKeyed + { } + [RegisterSingleton(ServiceType = typeof(IServiceKeyed), ServiceKey = "Alpha")] public class ServiceAlphaKeyed : IServiceKeyed { } @@ -585,6 +502,12 @@ public enum ServiceType """; + return Verify(source); + } + + + private static Task Verify(string source) + { var output = GetGeneratedOutput(source); return Verifier @@ -593,35 +516,44 @@ public enum ServiceType .ScrubLinesContaining("GeneratedCodeAttribute"); } - private static string GetGeneratedOutput(string source) where T : IIncrementalGenerator, new() { - var syntaxTree = CSharpSyntaxTree.ParseText(source); + var parseOptions = CSharpParseOptions.Default.WithPreprocessorSymbols( + "NET7_0_OR_GREATER", + "NET8_0_OR_GREATER", + "NET9_0_OR_GREATER", + "NET10_0_OR_GREATER"); + + var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions); var references = AppDomain.CurrentDomain.GetAssemblies() .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) - .Concat(new[] - { + .Concat( + [ MetadataReference.CreateFromFile(typeof(T).Assembly.Location), - MetadataReference.CreateFromFile(typeof(RegisterServicesAttribute).Assembly.Location), MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), - }); + ]); var compilation = CSharpCompilation.Create( "Test.Generator", - new[] { syntaxTree }, + [syntaxTree], references, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); var originalTreeCount = compilation.SyntaxTrees.Length; var generator = new T(); - var driver = CSharpGeneratorDriver.Create(generator); - driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics); + var driver = CSharpGeneratorDriver.Create( + generators: [generator.AsSourceGenerator()], + parseOptions: parseOptions); + + driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var trees = outputCompilation.SyntaxTrees.ToList(); + var generated = outputCompilation.SyntaxTrees + .Skip(originalTreeCount) + .FirstOrDefault(t => Path.GetFileName(t.FilePath) == "Injectio.g.cs"); - return trees.Count != originalTreeCount ? trees[^1].ToString() : string.Empty; + return generated?.ToString() ?? string.Empty; } } diff --git a/tests/Injectio.Tests/ServiceRegistrationWriterTests.cs b/tests/Injectio.Tests/ServiceRegistrationWriterTests.cs index 38476a3..d0457fb 100644 --- a/tests/Injectio.Tests/ServiceRegistrationWriterTests.cs +++ b/tests/Injectio.Tests/ServiceRegistrationWriterTests.cs @@ -1,12 +1,5 @@ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; - using Injectio.Generators; - -using VerifyXunit; - -using Xunit; +using Injectio.Generators.Models; namespace Injectio.Tests; diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorAnyKey.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorAnyKey.verified.txt new file mode 100644 index 0000000..4dd6240 --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorAnyKey.verified.txt @@ -0,0 +1,40 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Alpha") + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Beta") + ); + + global::Injectio.Extensions.DecorationExtensions.DecorateKeyed( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey, + static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorChainedOrder.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorChainedOrder.verified.txt new file mode 100644 index 0000000..5601d7e --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorChainedOrder.verified.txt @@ -0,0 +1,46 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Extensions.DecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + global::Injectio.Extensions.DecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorFactory.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorFactory.verified.txt new file mode 100644 index 0000000..a15f95f --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorFactory.verified.txt @@ -0,0 +1,41 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Extensions.DecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Injectio.Sample.LoggingDecorator.Create(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorGenericAttribute.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorGenericAttribute.verified.txt new file mode 100644 index 0000000..ffa218f --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorGenericAttribute.verified.txt @@ -0,0 +1,41 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Extensions.DecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorKeyed.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorKeyed.verified.txt new file mode 100644 index 0000000..363e4a8 --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorKeyed.verified.txt @@ -0,0 +1,35 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Alpha") + ); + + global::Injectio.Extensions.DecorationExtensions.DecorateKeyed( + serviceCollection, + "Alpha", + static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorOpenGeneric.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorOpenGeneric.verified.txt new file mode 100644 index 0000000..d17363f --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorOpenGeneric.verified.txt @@ -0,0 +1,38 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + typeof(global::Injectio.Sample.IRepo<>), + typeof(global::Injectio.Sample.Repo<>) + ) + ); + + global::Injectio.Extensions.DecorationExtensions.DecorateOpenGeneric( + serviceCollection, + typeof(global::Injectio.Sample.IRepo<>), + typeof(global::Injectio.Sample.LoggingRepo<>) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorSimple.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorSimple.verified.txt new file mode 100644 index 0000000..ffa218f --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorSimple.verified.txt @@ -0,0 +1,41 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Extensions.DecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorTags.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorTags.verified.txt new file mode 100644 index 0000000..c308ccf --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorTags.verified.txt @@ -0,0 +1,49 @@ +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + if (tagSet.Count == 0 || tagSet.Intersect(new[] { "FrontEnd" }).Any()) + { + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + } + + if (tagSet.Count == 0 || tagSet.Intersect(new[] { "FrontEnd" }).Any()) + { + global::Injectio.Extensions.DecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + } + + return serviceCollection; + } + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonServiceKeys.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonServiceKeys.verified.txt index 8137c4c..72b29ac 100644 --- a/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonServiceKeys.verified.txt +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationGeneratorTests.GenerateRegisterSingletonServiceKeys.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace Microsoft.Extensions.DependencyInjection @@ -20,32 +20,32 @@ namespace Microsoft.Extensions.DependencyInjection global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Alpha") + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Alpha") ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Beta") + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Beta") ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton(Injectio.Sample.ServiceType.Alpha) + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton(Injectio.Sample.ServiceType.Alpha) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton(Injectio.Sample.ServiceType.Beta) + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton(Injectio.Sample.ServiceType.Beta) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Charlie", global::Injectio.Sample.ServiceFactoryKeyed.ServiceFactory) + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Charlie", global::Injectio.Sample.ServiceFactoryKeyed.ServiceFactory) ); global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( serviceCollection, - global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Delta", global::Injectio.Sample.ServiceFactoryKeyed.ServiceFactory) + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Delta", global::Injectio.Sample.ServiceFactoryKeyed.ServiceFactory) ); return serviceCollection;