diff --git a/Directory.Packages.props b/Directory.Packages.props
index 05e11c0..221d1e3 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -8,11 +8,11 @@
-
+
-
\ No newline at end of file
+
diff --git a/Injectio.slnx b/Injectio.slnx
index 6b7aea5..babb955 100644
--- a/Injectio.slnx
+++ b/Injectio.slnx
@@ -8,10 +8,11 @@
+
-
+
diff --git a/README.md b/README.md
index 35978a6..b1970f5 100644
--- a/README.md
+++ b/README.md
@@ -17,21 +17,20 @@ Source generator that helps register attribute marked services in the dependency
- Module method registration
- Duplicate Strategy - Skip,Replace,Append
- Registration Strategy - Self, Implemented Interfaces, Self With Interfaces
+- Decorator registration (`RegisterDecorator`) — no runtime dependencies
### Usage
+#### Requirements
+
+Requires Roslyn 4.14 or later. This means Visual Studio 2022 version 17.14+ or Visual Studio 2026+, a current Rider release, and the .NET 9.0.300 SDK or newer. Older toolchains will not load the analyzer.
+
#### Add package
Add the nuget package project to your projects.
`dotnet add package Injectio`
-Prevent dependances from including Injectio
-
-```xml
-
-```
-
### Registration Attributes
Place registration attribute on class. The class will be discovered and registered.
@@ -40,6 +39,7 @@ Place registration attribute on class. The class will be discovered and registe
- `[RegisterScoped]` Marks the class as a scoped service
- `[RegisterTransient]` Marks the class as a transient service
- `[RegisterServices]` Marks the method to be called to register services
+- `[RegisterDecorator]` Marks the class as a decorator around an existing service
#### Attribute Properties
@@ -217,6 +217,119 @@ public class ServiceFactoryKeyed : IServiceKeyed
}
```
+#### Decorators
+
+Use the `RegisterDecorator` attribute to wrap an existing service registration without adding
+any runtime dependencies. The generator emits all decoration helpers directly into the
+consumer assembly.
+
+Decorators inherit the lifetime of the service they decorate. Apply multiple decorators by
+ordering them with the `Order` property — lower values are innermost (applied first), higher
+values are outermost (applied last).
+
+```c#
+public interface IService { }
+
+[RegisterSingleton]
+public class Service : IService { }
+
+[RegisterDecorator(Order = 1)]
+public class LoggingDecorator : IService
+{
+ public LoggingDecorator(IService inner) { }
+}
+
+[RegisterDecorator(Order = 2)]
+public class CachingDecorator : IService
+{
+ public CachingDecorator(IService inner) { }
+}
+```
+
+Resolution order for the sample above: `CachingDecorator → LoggingDecorator → Service`.
+
+##### Decorator Attribute Properties
+
+| Property | Description |
+|--------------------|------------------------------------------------------------------------------------------------|
+| ServiceType | The type of service to decorate. Required unless the generic attribute form is used. |
+| ImplementationType | The decorator type. If not set, the class the attribute is on will be used. |
+| ServiceKey | Decorate a specific keyed registration. Requires .NET 8+ Microsoft.Extensions.DependencyInjection. |
+| AnyKey | When `true`, decorate every keyed registration of `ServiceType` regardless of its key. |
+| Factory | Name of a static factory method that builds the decorator. |
+| Order | Ordering within the decoration chain. Lower = innermost. |
+| Tags | Comma/semicolon-delimited list of registration tags. |
+
+##### Keyed decoration
+
+Decorate a single keyed variant, or use `AnyKey` to decorate them all:
+
+```c#
+[RegisterSingleton(ServiceKey = "alpha")]
+public class AlphaService : IService { }
+
+[RegisterDecorator(AnyKey = true)]
+public class LoggingDecorator : IService
+{
+ public LoggingDecorator(IService inner) { }
+}
+```
+
+##### Factory-built decorators
+
+Provide a static factory on the decorator class for complex construction:
+
+```c#
+[RegisterDecorator(Factory = nameof(Create))]
+public class LoggingDecorator : IService
+{
+ public LoggingDecorator(IService inner) { }
+
+ public static IService Create(IServiceProvider serviceProvider, IService inner)
+ => new LoggingDecorator(inner);
+}
+```
+
+For keyed decorators the factory takes an additional `object?` parameter for the key:
+
+```c#
+public static IService Create(IServiceProvider serviceProvider, object? serviceKey, IService inner)
+ => new LoggingDecorator(inner);
+```
+
+##### Open-generic decoration
+
+Open-generic decorators apply to every closed registration of the matching service type.
+The generator supports decorating closed-generic registrations with an open-generic decorator
+class; purely open-generic implementation registrations (e.g. `(IRepo<>, Repo<>)`) are not
+decorated at runtime due to a Microsoft.Extensions.DependencyInjection limitation on factory
+registrations for open generic service types.
+
+```c#
+public interface IRepo { }
+
+[RegisterSingleton, StringRepo>]
+public class StringRepo : IRepo { }
+
+[RegisterDecorator(ServiceType = typeof(IRepo<>))]
+public class LoggingRepo : IRepo
+{
+ public LoggingRepo(IRepo inner) { }
+}
+```
+
+##### Tags
+
+Decorators support the same tag-filtering as registrations:
+
+```c#
+[RegisterDecorator(Tags = "FrontEnd")]
+public class FrontEndLoggingDecorator : IService
+{
+ public FrontEndLoggingDecorator(IService inner) { }
+}
+```
+
#### Register Method
When the service registration is complex, use the `RegisterServices` attribute on a method that has a parameter of `IServiceCollection` or `ServiceCollection`
diff --git a/src/Injectio.Attributes/DuplicateStrategy.cs b/src/Injectio.Attributes/DuplicateStrategy.cs
deleted file mode 100644
index a383547..0000000
--- a/src/Injectio.Attributes/DuplicateStrategy.cs
+++ /dev/null
@@ -1,20 +0,0 @@
-namespace Injectio.Attributes;
-
-///
-/// Service registration duplicate strategies
-///
-public enum DuplicateStrategy
-{
- ///
- /// Skips registrations for services that already exists.
- ///
- Skip = 0,
- ///
- /// Replaces existing service registrations.
- ///
- Replace = 1,
- ///
- /// Appends a new registration for existing services.
- ///
- Append = 2
-}
diff --git a/src/Injectio.Attributes/Injectio.Attributes.csproj b/src/Injectio.Attributes/Injectio.Attributes.csproj
deleted file mode 100644
index b39d861..0000000
--- a/src/Injectio.Attributes/Injectio.Attributes.csproj
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
- netstandard2.0;net8.0;net9.0;net10.0
- latest
- Injectio
-
-
-
-
-
-
-
-
-
-
-
diff --git a/src/Injectio.Attributes/RegisterAttribute.cs b/src/Injectio.Attributes/RegisterAttribute.cs
deleted file mode 100644
index d590536..0000000
--- a/src/Injectio.Attributes/RegisterAttribute.cs
+++ /dev/null
@@ -1,71 +0,0 @@
-namespace Injectio.Attributes;
-
-///
-/// Attribute to indicate the target class should be register for dependency injection
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public abstract class RegisterAttribute : Attribute
-{
- ///
- /// Initializes a new instance of the class.
- ///
- protected RegisterAttribute()
- {
- Duplicate = DuplicateStrategy.Skip;
- Registration = RegistrationStrategy.SelfWithInterfaces;
- }
-
-
- ///
- /// The that implements the service. If not set, the class the interface is on will be used.
- ///
- ///
- public Type? ImplementationType { get; set; }
-
- ///
- /// The of the service
- ///
- ///
- public Type? ServiceType { get; set; }
-
- ///
- /// Gets or sets the key of the service.
- ///
- /// The service key.
- ///
- public object? ServiceKey { get; set; }
-
- ///
- /// Name of a factory method to create new instances of the service implementation
- ///
- ///
- /// The method signature must be Func{IServiceProvider, object} and must be static
- ///
- ///
- public string? Factory { get; set; }
-
- ///
- /// Gets or sets the duplicate.
- ///
- ///
- /// The duplicate.
- ///
- public DuplicateStrategy Duplicate { get; set; }
-
- ///
- /// Gets or sets the registration.
- ///
- ///
- /// The registration.
- ///
- public RegistrationStrategy Registration { get; set; }
-
- ///
- /// Gets or sets the comma delimited list of service registration tags.
- ///
- ///
- /// The comma delimited list of service registration tags.
- ///
- public string? Tags { get; set; }
-}
diff --git a/src/Injectio.Attributes/RegisterScopedAttribute.cs b/src/Injectio.Attributes/RegisterScopedAttribute.cs
deleted file mode 100644
index 0f3cb00..0000000
--- a/src/Injectio.Attributes/RegisterScopedAttribute.cs
+++ /dev/null
@@ -1,69 +0,0 @@
-namespace Injectio.Attributes;
-
-///
-/// Attribute to indicate the target class should be register for dependency injection as a scoped service
-///
-/// Register the class as an implementation for IService
-///
-/// [RegisterScoped]
-/// public class ScopedService : IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterScopedAttribute : RegisterAttribute
-{
-}
-
-#if NET7_0_OR_GREATER
-///
-/// Attribute to indicate the target class should be register for dependency injection as a scoped service
-///
-/// The type of the service to add.
-/// Register the class as an implementation for IService
-///
-/// [RegisterScoped<IService>]
-/// public class ScopedService : IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterScopedAttribute : RegisterScopedAttribute
- where TService : class
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public RegisterScopedAttribute()
- {
- ServiceType = typeof(TService);
- }
-}
-
-///
-/// Attribute to indicate the target class should be register for dependency injection as a scoped service
-///
-/// The type of the service to add.
-/// The type of the implementation to use.
-/// Register the ScopedService class as an implementation for IService
-///
-/// [RegisterScoped<IService, ScopedService>]
-/// public class ScopedService: IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterScopedAttribute : RegisterScopedAttribute
- where TService : class
- where TImplementation : class, TService
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public RegisterScopedAttribute()
- {
- ServiceType = typeof(TService);
- ImplementationType = typeof(TImplementation);
- }
-}
-#endif
diff --git a/src/Injectio.Attributes/RegisterServicesAttribute.cs b/src/Injectio.Attributes/RegisterServicesAttribute.cs
deleted file mode 100644
index b2367bf..0000000
--- a/src/Injectio.Attributes/RegisterServicesAttribute.cs
+++ /dev/null
@@ -1,21 +0,0 @@
-namespace Injectio.Attributes;
-
-/// Attribute to indicate the method should be called to register services
-/// use the RegisterServices attribute
-///
-/// public class RegistrationModule
-/// {
-/// [RegisterServices]
-/// public static void Register(IServiceCollection services)
-/// {
-/// services.TryAddTransient<IModuleService, ModuleService>();
-/// }
-/// }
-///
-///
-[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterServicesAttribute : Attribute
-{
-
-}
diff --git a/src/Injectio.Attributes/RegisterSingletonAttribute.cs b/src/Injectio.Attributes/RegisterSingletonAttribute.cs
deleted file mode 100644
index 806ce7a..0000000
--- a/src/Injectio.Attributes/RegisterSingletonAttribute.cs
+++ /dev/null
@@ -1,69 +0,0 @@
-namespace Injectio.Attributes;
-
-///
-/// Attribute to indicate the target class should be register for dependency injection as a singleton service
-///
-/// Register the class as an implementation for IService
-///
-/// [RegisterSingleton]
-/// public class SingletonService : IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterSingletonAttribute : RegisterAttribute
-{
-}
-
-#if NET7_0_OR_GREATER
-///
-/// Attribute to indicate the target class should be register for dependency injection as a singleton service
-///
-/// The type of the service to add.
-/// Register the class as an implementation for IService
-///
-/// [RegisterSingleton<IService>]
-/// public class SingletonService : IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterSingletonAttribute : RegisterSingletonAttribute
- where TService : class
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public RegisterSingletonAttribute()
- {
- ServiceType = typeof(TService);
- }
-}
-
-///
-/// Attribute to indicate the target class should be register for dependency injection as a singleton service
-///
-/// The type of the service to add.
-/// The type of the implementation to use.
-/// Register the SingletonService class as an implementation for IService
-///
-/// [RegisterSingleton<IService, SingletonService>]
-/// public class SingletonService: IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterSingletonAttribute : RegisterSingletonAttribute
- where TService : class
- where TImplementation : class, TService
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public RegisterSingletonAttribute()
- {
- ServiceType = typeof(TService);
- ImplementationType = typeof(TImplementation);
- }
-}
-#endif
diff --git a/src/Injectio.Attributes/RegisterTransientAttribute.cs b/src/Injectio.Attributes/RegisterTransientAttribute.cs
deleted file mode 100644
index c3564f6..0000000
--- a/src/Injectio.Attributes/RegisterTransientAttribute.cs
+++ /dev/null
@@ -1,69 +0,0 @@
-namespace Injectio.Attributes;
-
-///
-/// Attribute to indicate the target class should be register for dependency injection as a transient service
-///
-/// Register the class as an implementation for IService
-///
-/// [RegisterTransient]
-/// public class TransientService : IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterTransientAttribute : RegisterAttribute
-{
-}
-
-#if NET7_0_OR_GREATER
-///
-/// Attribute to indicate the target class should be register for dependency injection as a transient service
-///
-/// The type of the service to add.
-/// Register the class as an implementation for IService
-///
-/// [RegisterTransient<IService>]
-/// public class TransientService : IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterTransientAttribute : RegisterTransientAttribute
- where TService : class
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public RegisterTransientAttribute()
- {
- ServiceType = typeof(TService);
- }
-}
-
-///
-/// Attribute to indicate the target class should be register for dependency injection as a transient service
-///
-/// The type of the service to add.
-/// The type of the implementation to use.
-/// Register the TransientService class as an implementation for IService
-///
-/// [RegisterTransient<IService, TransientService>]
-/// public class TransientService: IService { }
-///
-///
-[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
-[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
-public class RegisterTransientAttribute : RegisterTransientAttribute
- where TService : class
- where TImplementation : class, TService
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public RegisterTransientAttribute()
- {
- ServiceType = typeof(TService);
- ImplementationType = typeof(TImplementation);
- }
-}
-#endif
diff --git a/src/Injectio.Attributes/RegistrationStrategy.cs b/src/Injectio.Attributes/RegistrationStrategy.cs
deleted file mode 100644
index 3355135..0000000
--- a/src/Injectio.Attributes/RegistrationStrategy.cs
+++ /dev/null
@@ -1,26 +0,0 @@
-namespace Injectio.Attributes;
-
-///
-/// Register service type strategies
-///
-public enum RegistrationStrategy
-{
- ///
- /// Registers each matching concrete type as itself
- ///
- Self = 0,
- ///
- /// Registers each matching concrete type as all of its implemented interfaces.
- ///
- ImplementedInterfaces = 1,
- ///
- /// Registers each matching concrete type as all of its implemented interfaces and itself
- ///
- SelfWithInterfaces = 2,
- ///
- /// Registers each matching concrete type as all of its implemented interfaces and itself.
- /// For the interfaces a proxy-factory resolves the service from its type-name, so only one instance is created per lifetime
- ///
- /// For open-generic registrations, this behaves like
- SelfWithProxyFactory = 3
-}
diff --git a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md
index 65956d6..9ecbdb3 100644
--- a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md
+++ b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md
@@ -14,3 +14,10 @@ INJ0006 | Usage | Warning | Factory method has invalid signature
INJ0007 | Usage | Warning | Implementation does not implement service type
INJ0008 | Usage | Warning | Implementation type is abstract
INJ0009 | Usage | Warning | RegisterServices on non-static method in abstract class
+INJ0010 | Usage | Warning | Decorator does not implement service type
+INJ0011 | Usage | Warning | Decorator is missing service type
+INJ0012 | Usage | Warning | Decorator has no constructor accepting the inner service
+INJ0013 | Usage | Warning | Decorator factory method not found
+INJ0014 | Usage | Warning | Decorator factory method has invalid signature
+INJ0015 | Usage | Warning | Keyed decoration is not supported for open-generic services
+INJ0016 | Usage | Warning | Decorator target service is not registered in this compilation
diff --git a/src/Injectio.Generators/DiagnosticDescriptors.cs b/src/Injectio.Generators/DiagnosticDescriptors.cs
index 01c7157..f31812b 100644
--- a/src/Injectio.Generators/DiagnosticDescriptors.cs
+++ b/src/Injectio.Generators/DiagnosticDescriptors.cs
@@ -86,4 +86,69 @@ public static class DiagnosticDescriptors
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true
);
+
+ public static readonly DiagnosticDescriptor DecoratorDoesNotImplementService = new(
+ id: "INJ0010",
+ title: "Decorator does not implement service type",
+ messageFormat: "Decorator '{0}' does not implement or inherit from service type '{1}'",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true
+ );
+
+ public static readonly DiagnosticDescriptor DecoratorMissingServiceType = new(
+ id: "INJ0011",
+ title: "Decorator is missing service type",
+ messageFormat: "Decorator '{0}' must specify a ServiceType either via the generic attribute or the ServiceType property",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true
+ );
+
+ public static readonly DiagnosticDescriptor DecoratorMissingInnerConstructor = new(
+ id: "INJ0012",
+ title: "Decorator has no constructor accepting the inner service",
+ messageFormat: "Decorator '{0}' must expose a public constructor whose first parameter is of type '{1}' (or use Factory)",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true
+ );
+
+ public static readonly DiagnosticDescriptor DecoratorFactoryNotFound = new(
+ id: "INJ0013",
+ title: "Decorator factory method not found",
+ messageFormat: "Decorator factory method '{0}' was not found on type '{1}'",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true
+ );
+
+ public static readonly DiagnosticDescriptor DecoratorFactoryInvalidSignature = new(
+ id: "INJ0014",
+ title: "Decorator factory method has invalid signature",
+ messageFormat: "Decorator factory method '{0}' on type '{1}' must be static and accept (IServiceProvider, TService) for non-keyed or (IServiceProvider, object?, TService) for keyed decorators",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true
+ );
+
+ public static readonly DiagnosticDescriptor DecoratorOpenGenericKeyed = new(
+ id: "INJ0015",
+ title: "Keyed decoration is not supported for open-generic services",
+ messageFormat: "Decorator '{0}' targets open-generic service '{1}' and cannot be used with ServiceKey or AnyKey",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true,
+ customTags: [WellKnownDiagnosticTags.CompilationEnd]
+ );
+
+ public static readonly DiagnosticDescriptor DecoratorTargetNotRegistered = new(
+ id: "INJ0016",
+ title: "Decorator target service is not registered in this compilation",
+ messageFormat: "Decorator '{0}' targets service '{1}' but no matching registration was found; decoration will be skipped at runtime if the service is not registered elsewhere",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true,
+ customTags: [WellKnownDiagnosticTags.CompilationEnd]
+ );
}
diff --git a/src/Injectio.Generators/EquatableArray.cs b/src/Injectio.Generators/EquatableArray.cs
deleted file mode 100644
index 2df341e..0000000
--- a/src/Injectio.Generators/EquatableArray.cs
+++ /dev/null
@@ -1,60 +0,0 @@
-using System.Collections;
-using System.Diagnostics.CodeAnalysis;
-
-namespace Injectio.Generators;
-
-[ExcludeFromCodeCoverage]
-public readonly struct EquatableArray : IEquatable>, IEnumerable
- where T : IEquatable
-{
- public static readonly EquatableArray Empty = new();
-
-
- public EquatableArray() : this([]) { }
-
- public EquatableArray(T[] array) => Array = array ?? [];
-
- public EquatableArray(IEnumerable items) => Array = items.ToArray() ?? [];
-
-
- public T[] Array { get; }
-
- public int Count => Array.Length;
-
-
- public ReadOnlySpan AsSpan() => Array.AsSpan();
-
- public T[] AsArray() => Array;
-
-
- public static bool operator ==(EquatableArray left, EquatableArray right) => left.Equals(right);
-
- public static bool operator !=(EquatableArray left, EquatableArray right) => !left.Equals(right);
-
- public bool Equals(EquatableArray array) => Array.AsSpan().SequenceEqual(array.AsSpan());
-
- public override bool Equals(object? obj) => obj is EquatableArray array && Equals(this, array);
-
- public override int GetHashCode()
- {
- if (Array is not T[] array)
- return 0;
-
- var hashCode = 16777619;
-
- for (int i = 0; i < array.Length; i++)
- hashCode = unchecked((hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(array[i]));
-
- return hashCode;
- }
-
-
- IEnumerator IEnumerable.GetEnumerator() => (Array as IEnumerable).GetEnumerator();
-
- IEnumerator IEnumerable.GetEnumerator() => Array.GetEnumerator();
-
-
- public static implicit operator EquatableArray(T[] array) => new(array);
-
- public static implicit operator EquatableArray(List items) => new(items);
-}
diff --git a/src/Injectio.Generators/Extensions/StringExtensions.cs b/src/Injectio.Generators/Extensions/StringExtensions.cs
index 996d285..414a9cb 100644
--- a/src/Injectio.Generators/Extensions/StringExtensions.cs
+++ b/src/Injectio.Generators/Extensions/StringExtensions.cs
@@ -1,6 +1,8 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.RegularExpressions;
+using Injectio.Generators.Infrastructure;
+
using Microsoft.CodeAnalysis;
namespace Injectio.Generators.Extensions;
diff --git a/src/Injectio.Generators/Infrastructure/CollectionBuilderAttribute.cs b/src/Injectio.Generators/Infrastructure/CollectionBuilderAttribute.cs
new file mode 100644
index 0000000..8dae3d6
--- /dev/null
+++ b/src/Injectio.Generators/Infrastructure/CollectionBuilderAttribute.cs
@@ -0,0 +1,14 @@
+#pragma warning disable IDE0130 // Namespace does not match folder structure
+
+using System.ComponentModel;
+
+namespace System.Runtime.CompilerServices;
+
+[EditorBrowsable(EditorBrowsableState.Never)]
+[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Interface, Inherited = false)]
+internal sealed class CollectionBuilderAttribute(Type builderType, string methodName) : Attribute
+{
+ public Type BuilderType { get; } = builderType;
+
+ public string MethodName { get; } = methodName;
+}
diff --git a/src/Injectio.Generators/Infrastructure/EquatableArray.cs b/src/Injectio.Generators/Infrastructure/EquatableArray.cs
new file mode 100644
index 0000000..c632554
--- /dev/null
+++ b/src/Injectio.Generators/Infrastructure/EquatableArray.cs
@@ -0,0 +1,128 @@
+using System.Collections;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+
+namespace Injectio.Generators.Infrastructure;
+
+
+///
+/// Factory for creating instances from spans.
+///
+[ExcludeFromCodeCoverage]
+internal static class EquatableArray
+{
+ ///
+ /// Creates an from a .
+ ///
+ /// The element type.
+ /// The span of items to copy.
+ /// A new containing the items.
+ public static EquatableArray Create(ReadOnlySpan items)
+ where T : IEquatable
+ => new(items.ToArray());
+}
+
+///
+/// An immutable array wrapper that implements value-based equality, suitable for use
+/// in incremental generator models where structural comparison is required.
+///
+/// The element type, which must implement .
+[ExcludeFromCodeCoverage]
+[CollectionBuilder(typeof(EquatableArray), nameof(EquatableArray.Create))]
+public readonly struct EquatableArray : IEquatable>, IEnumerable
+ where T : IEquatable
+{
+ ///
+ /// An empty instance.
+ ///
+ public static readonly EquatableArray Empty = new();
+
+
+ ///
+ /// Initializes a new empty .
+ ///
+ public EquatableArray() : this([]) { }
+
+ ///
+ /// Initializes a new from an array.
+ ///
+ /// The array to wrap.
+ public EquatableArray(T[] array) => Array = array ?? [];
+
+ ///
+ /// Initializes a new from an enumerable.
+ ///
+ /// The items to copy into the array.
+ public EquatableArray(IEnumerable items) => Array = items.ToArray() ?? [];
+
+
+ ///
+ /// Gets the underlying array.
+ ///
+ public T[] Array { get; }
+
+ ///
+ /// Gets the number of elements in the array.
+ ///
+ public int Count => Array.Length;
+
+
+ ///
+ /// Returns the array as a .
+ ///
+ /// A read-only span over the array elements.
+ public ReadOnlySpan AsSpan() => Array.AsSpan();
+
+ ///
+ /// Returns the underlying array.
+ ///
+ /// The underlying array.
+ public T[] AsArray() => Array;
+
+
+ ///
+ public static bool operator ==(EquatableArray left, EquatableArray right) => left.Equals(right);
+
+ ///
+ public static bool operator !=(EquatableArray left, EquatableArray right) => !left.Equals(right);
+
+ ///
+ public bool Equals(EquatableArray array) => Array.AsSpan().SequenceEqual(array.AsSpan());
+
+ ///
+ public override bool Equals(object? obj) => obj is EquatableArray array && Equals(this, array);
+
+ ///
+ public override int GetHashCode()
+ {
+ if (Array is not T[] array)
+ return 0;
+
+ var hashCode = 16777619;
+
+ for (int i = 0; i < array.Length; i++)
+ hashCode = unchecked((hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(array[i]));
+
+ return hashCode;
+ }
+
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator() => (Array as IEnumerable).GetEnumerator();
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator() => Array.GetEnumerator();
+
+
+ ///
+ /// Implicitly converts an array to an .
+ ///
+ /// The array to wrap.
+ public static implicit operator EquatableArray(T[] array) => new(array);
+
+ ///
+ /// Implicitly converts a list to an .
+ ///
+ /// The list to copy.
+ public static implicit operator EquatableArray(List items) => new(items);
+}
diff --git a/src/Injectio.Generators/IndentedStringBuilder.cs b/src/Injectio.Generators/Infrastructure/IndentedStringBuilder.cs
similarity index 98%
rename from src/Injectio.Generators/IndentedStringBuilder.cs
rename to src/Injectio.Generators/Infrastructure/IndentedStringBuilder.cs
index 14ab0d9..0f0e66c 100644
--- a/src/Injectio.Generators/IndentedStringBuilder.cs
+++ b/src/Injectio.Generators/Infrastructure/IndentedStringBuilder.cs
@@ -1,7 +1,7 @@
using System.Diagnostics.CodeAnalysis;
using System.Text;
-namespace Injectio.Generators;
+namespace Injectio.Generators.Infrastructure;
///
/// A thin wrapper over that adds indentation to each line built.
@@ -352,7 +352,7 @@ private void DoIndent()
_indentPending = false;
}
- private sealed class Indenter : IDisposable
+ private readonly struct Indenter : IDisposable
{
private readonly IndentedStringBuilder _stringBuilder;
@@ -367,7 +367,7 @@ public void Dispose()
=> _stringBuilder.DecrementIndent();
}
- private sealed class IndentSuspender : IDisposable
+ private readonly struct IndentSuspender : IDisposable
{
private readonly IndentedStringBuilder _stringBuilder;
private readonly byte _indent;
diff --git a/src/Injectio.Generators/IsExternalInit.cs b/src/Injectio.Generators/Infrastructure/IsExternalInit.cs
similarity index 66%
rename from src/Injectio.Generators/IsExternalInit.cs
rename to src/Injectio.Generators/Infrastructure/IsExternalInit.cs
index 75cd67b..91eb900 100644
--- a/src/Injectio.Generators/IsExternalInit.cs
+++ b/src/Injectio.Generators/Infrastructure/IsExternalInit.cs
@@ -1,3 +1,5 @@
+#pragma warning disable IDE0130 // Namespace does not match folder structure
+
using System.ComponentModel;
namespace System.Runtime.CompilerServices;
diff --git a/src/Injectio.Generators/Injectio.Generators.csproj b/src/Injectio.Generators/Injectio.Generators.csproj
index a686a32..95b64a8 100644
--- a/src/Injectio.Generators/Injectio.Generators.csproj
+++ b/src/Injectio.Generators/Injectio.Generators.csproj
@@ -3,7 +3,11 @@
netstandard2.0
true
- false
+ true
+ Injectio
+ true
+ false
+ true
cs
true
@@ -13,4 +17,14 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Injectio.Generators/KnownTypes.cs b/src/Injectio.Generators/KnownTypes.cs
index 77efad4..38048aa 100644
--- a/src/Injectio.Generators/KnownTypes.cs
+++ b/src/Injectio.Generators/KnownTypes.cs
@@ -10,19 +10,31 @@ public static class KnownTypes
public const string TransientAttributeShortName = "RegisterTransient";
public const string TransientAttributeTypeName = $"{TransientAttributeShortName}Attribute";
public const string TransientAttributeFullName = $"{AbstractionNamespace}.{TransientAttributeTypeName}";
+ public const string TransientAttributeFullNameT1 = $"{TransientAttributeFullName}`1";
+ public const string TransientAttributeFullNameT2 = $"{TransientAttributeFullName}`2";
public const string SingletonAttributeShortName = "RegisterSingleton";
public const string SingletonAttributeTypeName = $"{SingletonAttributeShortName}Attribute";
public const string SingletonAttributeFullName = $"{AbstractionNamespace}.{SingletonAttributeTypeName}";
+ public const string SingletonAttributeFullNameT1 = $"{SingletonAttributeFullName}`1";
+ public const string SingletonAttributeFullNameT2 = $"{SingletonAttributeFullName}`2";
public const string ScopedAttributeShortName = "RegisterScoped";
public const string ScopedAttributeTypeName = $"{ScopedAttributeShortName}Attribute";
public const string ScopedAttributeFullName = $"{AbstractionNamespace}.{ScopedAttributeTypeName}";
+ public const string ScopedAttributeFullNameT1 = $"{ScopedAttributeFullName}`1";
+ public const string ScopedAttributeFullNameT2 = $"{ScopedAttributeFullName}`2";
public const string ModuleAttributeShortName = "RegisterServices";
public const string ModuleAttributeTypeName = $"{ModuleAttributeShortName}Attribute";
public const string ModuleAttributeFullName = $"{AbstractionNamespace}.{ModuleAttributeTypeName}";
+ public const string DecoratorAttributeShortName = "RegisterDecorator";
+ public const string DecoratorAttributeTypeName = $"{DecoratorAttributeShortName}Attribute";
+ public const string DecoratorAttributeFullName = $"{AbstractionNamespace}.{DecoratorAttributeTypeName}";
+ public const string DecoratorAttributeFullNameT1 = $"{DecoratorAttributeFullName}`1";
+ public const string DecoratorAttributeFullNameT2 = $"{DecoratorAttributeFullName}`2";
+
public const string ServiceLifetimeSingletonShortName = "Singleton";
public const string ServiceLifetimeSingletonTypeName = $"ServiceLifetime.{ServiceLifetimeSingletonShortName}";
diff --git a/src/Injectio.Generators/Models/DecoratorRegistration.cs b/src/Injectio.Generators/Models/DecoratorRegistration.cs
new file mode 100644
index 0000000..4a65526
--- /dev/null
+++ b/src/Injectio.Generators/Models/DecoratorRegistration.cs
@@ -0,0 +1,14 @@
+using Injectio.Generators.Infrastructure;
+
+namespace Injectio.Generators.Models;
+
+public record DecoratorRegistration(
+ string DecoratorType,
+ string ServiceType,
+ string? ServiceKey,
+ bool IsAnyKey,
+ string? Factory,
+ int Order,
+ EquatableArray Tags,
+ bool IsOpenGeneric = false
+);
diff --git a/src/Injectio.Generators/MethodOptions.cs b/src/Injectio.Generators/Models/MethodOptions.cs
similarity index 65%
rename from src/Injectio.Generators/MethodOptions.cs
rename to src/Injectio.Generators/Models/MethodOptions.cs
index dc2aabb..01b3b75 100644
--- a/src/Injectio.Generators/MethodOptions.cs
+++ b/src/Injectio.Generators/Models/MethodOptions.cs
@@ -1,4 +1,4 @@
-namespace Injectio.Generators;
+namespace Injectio.Generators.Models;
public record MethodOptions(
string? Name,
diff --git a/src/Injectio.Generators/ModuleRegistration.cs b/src/Injectio.Generators/Models/ModuleRegistration.cs
similarity index 77%
rename from src/Injectio.Generators/ModuleRegistration.cs
rename to src/Injectio.Generators/Models/ModuleRegistration.cs
index b8ea6ce..a3e36b2 100644
--- a/src/Injectio.Generators/ModuleRegistration.cs
+++ b/src/Injectio.Generators/Models/ModuleRegistration.cs
@@ -1,4 +1,4 @@
-namespace Injectio.Generators;
+namespace Injectio.Generators.Models;
public record ModuleRegistration(
string ClassName,
diff --git a/src/Injectio.Generators/Models/RegistrationContext.cs b/src/Injectio.Generators/Models/RegistrationContext.cs
new file mode 100644
index 0000000..e4b8d10
--- /dev/null
+++ b/src/Injectio.Generators/Models/RegistrationContext.cs
@@ -0,0 +1,11 @@
+using Injectio.Generators.Infrastructure;
+
+namespace Injectio.Generators.Models;
+
+public record RegistrationContext(
+ EquatableArray ServiceRegistrations,
+ EquatableArray ModuleRegistrations,
+ EquatableArray DecoratorRegistrations,
+ string? AssemblyName,
+ MethodOptions? MethodOptions
+);
diff --git a/src/Injectio.Generators/ServiceRegistration.cs b/src/Injectio.Generators/Models/ServiceRegistration.cs
similarity index 80%
rename from src/Injectio.Generators/ServiceRegistration.cs
rename to src/Injectio.Generators/Models/ServiceRegistration.cs
index 7be2c98..5eec59f 100644
--- a/src/Injectio.Generators/ServiceRegistration.cs
+++ b/src/Injectio.Generators/Models/ServiceRegistration.cs
@@ -1,6 +1,8 @@
using System.Diagnostics.CodeAnalysis;
-namespace Injectio.Generators;
+using Injectio.Generators.Infrastructure;
+
+namespace Injectio.Generators.Models;
public record ServiceRegistration(
string Lifetime,
diff --git a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs
index dbec02e..a4600ee 100644
--- a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs
+++ b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs
@@ -20,7 +20,15 @@ public class ServiceRegistrationAnalyzer : DiagnosticAnalyzer
DiagnosticDescriptors.FactoryMethodInvalidSignature,
DiagnosticDescriptors.ServiceTypeMismatch,
DiagnosticDescriptors.AbstractImplementationType,
- DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass);
+ DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass,
+ DiagnosticDescriptors.DecoratorDoesNotImplementService,
+ DiagnosticDescriptors.DecoratorMissingServiceType,
+ DiagnosticDescriptors.DecoratorMissingInnerConstructor,
+ DiagnosticDescriptors.DecoratorFactoryNotFound,
+ DiagnosticDescriptors.DecoratorFactoryInvalidSignature,
+ DiagnosticDescriptors.DecoratorOpenGenericKeyed,
+ DiagnosticDescriptors.DecoratorTargetNotRegistered
+ );
public override void Initialize(AnalysisContext context)
{
@@ -29,6 +37,234 @@ public override void Initialize(AnalysisContext context)
context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method);
context.RegisterSymbolAction(AnalyzeNamedType, SymbolKind.NamedType);
+ context.RegisterCompilationStartAction(AnalyzeCompilation);
+ }
+
+ private static void AnalyzeCompilation(CompilationStartAnalysisContext context)
+ {
+ var decorators = new List();
+ var registeredServices = new HashSet(StringComparer.Ordinal);
+ var hasModule = false;
+
+ context.RegisterSymbolAction(symbolContext =>
+ {
+ if (symbolContext.Symbol is IMethodSymbol methodSymbol)
+ {
+ foreach (var attribute in methodSymbol.GetAttributes())
+ {
+ if (SymbolHelpers.IsMethodAttribute(attribute))
+ {
+ lock (registeredServices) hasModule = true;
+ break;
+ }
+ }
+ return;
+ }
+
+ if (symbolContext.Symbol is not INamedTypeSymbol classSymbol)
+ return;
+
+ if (classSymbol.IsStatic)
+ return;
+
+ var attributes = classSymbol.GetAttributes();
+
+ foreach (var attribute in attributes)
+ {
+ if (SymbolHelpers.IsDecoratorAttribute(attribute))
+ {
+ var info = ExtractDecoratorInfo(classSymbol, attribute);
+ if (info != null)
+ lock (decorators) decorators.Add(info);
+ continue;
+ }
+
+ if (!SymbolHelpers.IsKnownAttribute(attribute, out _))
+ continue;
+
+ CollectRegisteredServiceTypes(classSymbol, attribute, registeredServices);
+ }
+ }, SymbolKind.NamedType, SymbolKind.Method);
+
+ context.RegisterCompilationEndAction(endContext =>
+ {
+ foreach (var decorator in decorators)
+ {
+ // INJ0015 — open-generic + keyed combination
+ if (decorator.IsOpenGeneric && (decorator.HasServiceKey || decorator.IsAnyKey))
+ {
+ endContext.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorOpenGenericKeyed,
+ decorator.Location,
+ decorator.DecoratorType,
+ decorator.ServiceType));
+ }
+
+ // INJ0016 — target not registered (only when no module could register it dynamically)
+ if (!hasModule && !decorator.HasServiceKey && !decorator.IsAnyKey)
+ {
+ if (!registeredServices.Contains(decorator.ServiceType))
+ {
+ endContext.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorTargetNotRegistered,
+ decorator.Location,
+ decorator.DecoratorType,
+ decorator.ServiceType));
+ }
+ }
+ }
+ });
+ }
+
+ private sealed class DecoratorInfo
+ {
+ public string DecoratorType { get; set; } = string.Empty;
+ public string ServiceType { get; set; } = string.Empty;
+ public bool IsOpenGeneric { get; set; }
+ public bool HasServiceKey { get; set; }
+ public bool IsAnyKey { get; set; }
+ public Location Location { get; set; } = Location.None;
+ }
+
+ private static DecoratorInfo? ExtractDecoratorInfo(INamedTypeSymbol classSymbol, AttributeData attribute)
+ {
+ string? serviceType = null;
+ bool hasServiceKey = false;
+ bool isAnyKey = false;
+ bool isOpenGeneric = false;
+
+ var attributeClass = attribute.AttributeClass;
+ if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length >= 1)
+ {
+ var typeArgument = attributeClass.TypeArguments[0];
+ if (typeArgument is INamedTypeSymbol namedService && namedService.IsGenericType && namedService.IsUnboundGenericType)
+ isOpenGeneric = true;
+
+ serviceType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+
+ foreach (var parameter in attribute.NamedArguments)
+ {
+ var name = parameter.Key;
+ var value = parameter.Value.Value;
+
+ switch (name)
+ {
+ case "ServiceType":
+ if (value is INamedTypeSymbol svc)
+ {
+ if (svc.IsGenericType && svc.IsUnboundGenericType)
+ isOpenGeneric = true;
+ serviceType = svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ break;
+ case "ServiceKey":
+ hasServiceKey = value is not null;
+ break;
+ case "AnyKey":
+ if (value is bool b)
+ isAnyKey = b;
+ break;
+ }
+ }
+
+ if (serviceType is null)
+ return null;
+
+ var location = classSymbol.Locations.Length > 0 ? classSymbol.Locations[0] : Location.None;
+
+ return new DecoratorInfo
+ {
+ DecoratorType = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat),
+ ServiceType = serviceType,
+ IsOpenGeneric = isOpenGeneric,
+ HasServiceKey = hasServiceKey,
+ IsAnyKey = isAnyKey,
+ Location = location,
+ };
+ }
+
+ private static void CollectRegisteredServiceTypes(
+ INamedTypeSymbol classSymbol,
+ AttributeData attribute,
+ HashSet registeredServices)
+ {
+ string? implementationType = null;
+ string? registrationStrategy = null;
+ var localServiceTypes = new List();
+
+ var attributeClass = attribute.AttributeClass;
+ if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length)
+ {
+ for (var index = 0; index < attributeClass.TypeParameters.Length; index++)
+ {
+ var typeParameter = attributeClass.TypeParameters[index];
+ var typeArgument = attributeClass.TypeArguments[index];
+
+ if (typeParameter.Name == "TService" || index == 0)
+ localServiceTypes.Add(typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat));
+ else if (typeParameter.Name == "TImplementation" || index == 1)
+ implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ }
+
+ foreach (var parameter in attribute.NamedArguments)
+ {
+ var name = parameter.Key;
+ var value = parameter.Value.Value;
+ if (value is null) continue;
+
+ switch (name)
+ {
+ case "ServiceType":
+ if (value is INamedTypeSymbol svc)
+ localServiceTypes.Add(svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat));
+ break;
+ case "ImplementationType":
+ if (value is INamedTypeSymbol impl)
+ implementationType = impl.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ break;
+ case "Registration":
+ registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value);
+ break;
+ }
+ }
+
+ implementationType ??= classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+
+ if (registrationStrategy == null && localServiceTypes.Count == 0)
+ registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName;
+
+ bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName
+ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
+ or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName
+ or null;
+
+ if (includeInterfaces)
+ {
+ foreach (var iface in classSymbol.AllInterfaces)
+ {
+ if (iface.ConstructedFrom.ToString() == "System.IEquatable")
+ continue;
+
+ var unbound = SymbolHelpers.ToUnboundGenericType(iface);
+ localServiceTypes.Add(unbound.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat));
+ }
+ }
+
+ bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName
+ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
+ or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName
+ or null;
+
+ if (includeSelf || localServiceTypes.Count == 0)
+ localServiceTypes.Add(implementationType);
+
+ lock (registeredServices)
+ {
+ foreach (var t in localServiceTypes)
+ registeredServices.Add(t);
+ }
}
private static void AnalyzeMethod(SymbolAnalysisContext context)
@@ -126,17 +362,249 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context)
foreach (var attribute in attributes)
{
- if (!SymbolHelpers.IsKnownAttribute(attribute, out _))
- continue;
-
var location = classSymbol.Locations.Length > 0
? classSymbol.Locations[0]
: Location.None;
+ if (SymbolHelpers.IsDecoratorAttribute(attribute))
+ {
+ AnalyzeDecoratorAttribute(context, classSymbol, attribute, location);
+ continue;
+ }
+
+ if (!SymbolHelpers.IsKnownAttribute(attribute, out _))
+ continue;
+
AnalyzeRegistrationAttribute(context, classSymbol, attribute, location);
}
}
+ private static void AnalyzeDecoratorAttribute(
+ SymbolAnalysisContext context,
+ INamedTypeSymbol classSymbol,
+ AttributeData attribute,
+ Location location)
+ {
+ string? serviceTypeName = null;
+ INamedTypeSymbol? serviceTypeSymbol = null;
+ string? factory = null;
+ bool hasServiceKey = false;
+ bool isAnyKey = false;
+
+ var attributeClass = attribute.AttributeClass;
+ if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length >= 1)
+ {
+ if (attributeClass.TypeArguments[0] is INamedTypeSymbol serviceArg)
+ {
+ serviceTypeSymbol = serviceArg;
+ serviceTypeName = serviceArg.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ }
+
+ foreach (var parameter in attribute.NamedArguments)
+ {
+ var name = parameter.Key;
+ var value = parameter.Value.Value;
+
+ switch (name)
+ {
+ case "ServiceType":
+ if (value is INamedTypeSymbol svc)
+ {
+ serviceTypeSymbol = svc;
+ serviceTypeName = svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ break;
+ case "Factory":
+ factory = value?.ToString();
+ break;
+ case "ServiceKey":
+ hasServiceKey = value is not null;
+ break;
+ case "AnyKey":
+ if (value is bool b)
+ isAnyKey = b;
+ break;
+ }
+ }
+
+ // INJ0011 — missing service type
+ if (serviceTypeName is null)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorMissingServiceType,
+ location,
+ classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)));
+ return;
+ }
+
+ // INJ0010 — class does not implement service
+ var classTypeName = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (serviceTypeName != classTypeName)
+ {
+ var implementsService = false;
+
+ foreach (var iface in classSymbol.AllInterfaces)
+ {
+ var ifaceName = iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (ifaceName == serviceTypeName)
+ {
+ implementsService = true;
+ break;
+ }
+
+ var unboundIface = SymbolHelpers.ToUnboundGenericType(iface);
+ if (!SymbolEqualityComparer.Default.Equals(unboundIface, iface))
+ {
+ var unboundName = unboundIface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (unboundName == serviceTypeName)
+ {
+ implementsService = true;
+ break;
+ }
+ }
+ }
+
+ if (!implementsService)
+ {
+ var baseType = classSymbol.BaseType;
+ while (baseType is not null)
+ {
+ var baseName = baseType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (baseName == serviceTypeName) { implementsService = true; break; }
+
+ var unboundBase = SymbolHelpers.ToUnboundGenericType(baseType);
+ if (!SymbolEqualityComparer.Default.Equals(unboundBase, baseType))
+ {
+ var unboundBaseName = unboundBase.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (unboundBaseName == serviceTypeName) { implementsService = true; break; }
+ }
+
+ baseType = baseType.BaseType;
+ }
+ }
+
+ if (!implementsService)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorDoesNotImplementService,
+ location,
+ classTypeName,
+ serviceTypeName));
+ }
+ }
+
+ // INJ0012/13/14 — constructor or factory validation
+ if (factory.HasValue())
+ {
+ ValidateDecoratorFactory(context, classSymbol, factory!, hasServiceKey || isAnyKey, location);
+ }
+ else
+ {
+ var hasCompatibleCtor = false;
+ foreach (var ctor in classSymbol.InstanceConstructors)
+ {
+ if (ctor.DeclaredAccessibility == Accessibility.Private) continue;
+ if (ctor.Parameters.Length == 0) continue;
+
+ var firstParamType = ctor.Parameters[0].Type.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (firstParamType == serviceTypeName)
+ {
+ hasCompatibleCtor = true;
+ break;
+ }
+
+ if (ctor.Parameters[0].Type is INamedTypeSymbol paramNamed)
+ {
+ var unboundParam = SymbolHelpers.ToUnboundGenericType(paramNamed);
+ if (!SymbolEqualityComparer.Default.Equals(unboundParam, paramNamed))
+ {
+ var unboundName = unboundParam.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (unboundName == serviceTypeName)
+ {
+ hasCompatibleCtor = true;
+ break;
+ }
+ }
+ }
+
+ // any parameter matches?
+ foreach (var param in ctor.Parameters)
+ {
+ var paramType = param.Type.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (paramType == serviceTypeName)
+ {
+ hasCompatibleCtor = true;
+ break;
+ }
+ }
+
+ if (hasCompatibleCtor) break;
+ }
+
+ if (!hasCompatibleCtor && classSymbol.InstanceConstructors.Length > 0)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorMissingInnerConstructor,
+ location,
+ classTypeName,
+ serviceTypeName));
+ }
+ }
+ }
+
+ private static void ValidateDecoratorFactory(
+ SymbolAnalysisContext context,
+ INamedTypeSymbol classSymbol,
+ string factoryMethodName,
+ bool isKeyed,
+ Location location)
+ {
+ var className = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ var members = classSymbol.GetMembers(factoryMethodName);
+ var factoryMethods = new List();
+
+ foreach (var member in members)
+ {
+ if (member is IMethodSymbol method)
+ factoryMethods.Add(method);
+ }
+
+ if (factoryMethods.Count == 0)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorFactoryNotFound,
+ location,
+ factoryMethodName,
+ className));
+ return;
+ }
+
+ var expectedParamCount = isKeyed ? 3 : 2;
+
+ foreach (var method in factoryMethods)
+ {
+ if (!method.IsStatic) continue;
+ if (method.Parameters.Length != expectedParamCount) continue;
+
+ if (!SymbolHelpers.IsServiceProvider(method.Parameters[0])) continue;
+
+ if (isKeyed)
+ {
+ if (method.Parameters[1].Type.SpecialType != SpecialType.System_Object) continue;
+ // parameter[2] is the inner service — not strictly checked
+ }
+
+ return; // valid overload found
+ }
+
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.DecoratorFactoryInvalidSignature,
+ location,
+ factoryMethodName,
+ className));
+ }
+
private static void AnalyzeRegistrationAttribute(
SymbolAnalysisContext context,
INamedTypeSymbol classSymbol,
diff --git a/src/Injectio.Generators/ServiceRegistrationContext.cs b/src/Injectio.Generators/ServiceRegistrationContext.cs
deleted file mode 100644
index 2451cce..0000000
--- a/src/Injectio.Generators/ServiceRegistrationContext.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-using System.Diagnostics.CodeAnalysis;
-
-using Microsoft.CodeAnalysis;
-
-namespace Injectio.Generators;
-
-public record ServiceRegistrationContext(
- EquatableArray? ServiceRegistrations = null,
- EquatableArray? ModuleRegistrations = null
-);
diff --git a/src/Injectio.Generators/ServiceRegistrationGenerator.cs b/src/Injectio.Generators/ServiceRegistrationGenerator.cs
index d4be12c..717d348 100644
--- a/src/Injectio.Generators/ServiceRegistrationGenerator.cs
+++ b/src/Injectio.Generators/ServiceRegistrationGenerator.cs
@@ -1,8 +1,11 @@
using System.Collections.Immutable;
+using System.IO;
using System.Text;
using System.Text.RegularExpressions;
using Injectio.Generators.Extensions;
+using Injectio.Generators.Infrastructure;
+using Injectio.Generators.Models;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
@@ -16,18 +19,111 @@ public class ServiceRegistrationGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
- // find all classes and methods with attributes
- var registrations = context.SyntaxProvider
- .CreateSyntaxProvider(
- predicate: SyntacticPredicate,
- transform: SemanticTransform
- )
- .Where(static context =>
- context is not null
- && (context.ServiceRegistrations?.Count > 0 || context.ModuleRegistrations?.Count > 0)
- )
- .Collect()
- .WithTrackingName("Registrations");
+ context.RegisterPostInitializationOutput(static ctx =>
+ {
+ ctx.AddEmbeddedAttributeDefinition();
+
+ ctx.AddSource(
+ hintName: "Injectio.Attributes.g.cs",
+ sourceText: SourceText.From(LoadEmbedded("Attributes.cs"), Encoding.UTF8)
+ );
+ ctx.AddSource(
+ hintName: "Injectio.Extensions.g.cs",
+ sourceText: SourceText.From(LoadEmbedded("Extensions.cs"), Encoding.UTF8)
+ );
+ });
+
+ // separate pipeline per attribute, each returning its specific model
+ var transientRegistrations = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.TransientAttributeFullName,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeTransientFullName))
+ .WithTrackingName("TransientRegistrations");
+
+ var transientRegistrationsT1 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.TransientAttributeFullNameT1,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeTransientFullName))
+ .WithTrackingName("TransientRegistrationsT1");
+
+ var transientRegistrationsT2 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.TransientAttributeFullNameT2,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeTransientFullName))
+ .WithTrackingName("TransientRegistrationsT2");
+
+ var scopedRegistrations = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.ScopedAttributeFullName,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeScopedFullName))
+ .WithTrackingName("ScopedRegistrations");
+
+ var scopedRegistrationsT1 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.ScopedAttributeFullNameT1,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeScopedFullName))
+ .WithTrackingName("ScopedRegistrationsT1");
+
+ var scopedRegistrationsT2 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.ScopedAttributeFullNameT2,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeScopedFullName))
+ .WithTrackingName("ScopedRegistrationsT2");
+
+ var singletonRegistrations = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.SingletonAttributeFullName,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeSingletonFullName))
+ .WithTrackingName("SingletonRegistrations");
+
+ var singletonRegistrationsT1 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.SingletonAttributeFullNameT1,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeSingletonFullName))
+ .WithTrackingName("SingletonRegistrationsT1");
+
+ var singletonRegistrationsT2 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.SingletonAttributeFullNameT2,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformServiceRegistration(ctx, KnownTypes.ServiceLifetimeSingletonFullName))
+ .WithTrackingName("SingletonRegistrationsT2");
+
+ var decoratorRegistrations = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.DecoratorAttributeFullName,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformDecoratorRegistration(ctx))
+ .WithTrackingName("DecoratorRegistrations");
+
+ var decoratorRegistrationsT1 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.DecoratorAttributeFullNameT1,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformDecoratorRegistration(ctx))
+ .WithTrackingName("DecoratorRegistrationsT1");
+
+ var decoratorRegistrationsT2 = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.DecoratorAttributeFullNameT2,
+ predicate: static (node, _) => IsNonAbstractNonStaticType(node),
+ transform: static (ctx, _) => TransformDecoratorRegistration(ctx))
+ .WithTrackingName("DecoratorRegistrationsT2");
+
+ var moduleRegistrations = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ KnownTypes.ModuleAttributeFullName,
+ predicate: static (node, _) => node is MethodDeclarationSyntax,
+ transform: static (ctx, _) => TransformModuleRegistration(ctx))
+ .WithTrackingName("ModuleRegistrations");
// include compilation options
var assemblyName = context.CompilationProvider
@@ -35,7 +131,7 @@ context is not null
.WithTrackingName("AssemblyName");
// include config options
- var methodName = context.AnalyzerConfigOptionsProvider
+ var methodOptions = context.AnalyzerConfigOptionsProvider
.Select(static (c, _) =>
{
c.GlobalOptions.TryGetValue("build_property.InjectioName", out var methodName);
@@ -44,38 +140,98 @@ context is not null
})
.WithTrackingName("Options");
- var options = assemblyName.Combine(methodName);
- var generation = registrations.Combine(options);
+ // combine all service registration pipelines
+ var allServiceRegistrations = transientRegistrations
+ .Collect()
+ .Combine(transientRegistrationsT1.Collect())
+ .Combine(transientRegistrationsT2.Collect())
+ .Combine(scopedRegistrations.Collect())
+ .Combine(scopedRegistrationsT1.Collect())
+ .Combine(scopedRegistrationsT2.Collect())
+ .Combine(singletonRegistrations.Collect())
+ .Combine(singletonRegistrationsT1.Collect())
+ .Combine(singletonRegistrationsT2.Collect())
+ .Select(static (combined, _) =>
+ {
+ var ((((((((t, t1), t2), s), s1), s2), si), si1), si2) = combined;
+ var result = new List();
+ foreach (var arr in t) result.AddRange(arr);
+ foreach (var arr in t1) result.AddRange(arr);
+ foreach (var arr in t2) result.AddRange(arr);
+ foreach (var arr in s) result.AddRange(arr);
+ foreach (var arr in s1) result.AddRange(arr);
+ foreach (var arr in s2) result.AddRange(arr);
+ foreach (var arr in si) result.AddRange(arr);
+ foreach (var arr in si1) result.AddRange(arr);
+ foreach (var arr in si2) result.AddRange(arr);
+ return new EquatableArray(result);
+ })
+ .WithTrackingName("AllServiceRegistrations");
+
+ // combine all decorator registration pipelines
+ var allDecoratorRegistrations = decoratorRegistrations
+ .Collect()
+ .Combine(decoratorRegistrationsT1.Collect())
+ .Combine(decoratorRegistrationsT2.Collect())
+ .Select(static (combined, _) =>
+ {
+ var ((d, d1), d2) = combined;
+ var result = new List();
+ foreach (var arr in d) result.AddRange(arr.AsArray());
+ foreach (var arr in d1) result.AddRange(arr.AsArray());
+ foreach (var arr in d2) result.AddRange(arr.AsArray());
+ return new EquatableArray(result);
+ })
+ .WithTrackingName("AllDecoratorRegistrations");
+
+ // combine all pipelines into a single generation context
+ var generation = allServiceRegistrations
+ .Combine(moduleRegistrations.Collect())
+ .Combine(allDecoratorRegistrations)
+ .Combine(assemblyName)
+ .Combine(methodOptions)
+ .Select(static (combined, _) =>
+ {
+ var ((((serviceRegs, moduleRegs), decoratorRegs), asmName), options) = combined;
+
+ return new RegistrationContext(
+ ServiceRegistrations: serviceRegs,
+ ModuleRegistrations: new EquatableArray(moduleRegs.Where(static m => m is not null).Cast().ToList()),
+ DecoratorRegistrations: decoratorRegs,
+ AssemblyName: asmName,
+ MethodOptions: options
+ );
+ })
+ .WithTrackingName("Generation");
context.RegisterSourceOutput(generation, ExecuteGeneration);
}
- private void ExecuteGeneration(
- SourceProductionContext sourceContext,
- (ImmutableArray Registrations, (string? AssemblyName, MethodOptions? MethodOptions) Options) source)
+ private void ExecuteGeneration(SourceProductionContext sourceContext, RegistrationContext source)
{
- var serviceRegistrations = source.Registrations
- .SelectMany(m => m?.ServiceRegistrations ?? Array.Empty())
- .Where(m => m is not null)
- .ToArray();
+ var serviceRegistrations = source.ServiceRegistrations.AsArray();
- var moduleRegistrations = source.Registrations
- .SelectMany(m => m?.ModuleRegistrations ?? Array.Empty())
- .Where(m => m is not null)
+ var moduleRegistrations = source.ModuleRegistrations.AsArray();
+
+ var decoratorRegistrations = source.DecoratorRegistrations
+ .OrderBy(m => m.ServiceType, StringComparer.Ordinal)
+ .ThenBy(m => m.Order)
+ .ThenBy(m => m.DecoratorType, StringComparer.Ordinal)
.ToArray();
// compute extension method name
- var methodName = source.Options.MethodOptions?.Name;
+ var methodName = source.MethodOptions?.Name;
if (methodName.IsNullOrWhiteSpace())
- methodName = Regex.Replace(source.Options.AssemblyName, "\\W", "");
+ methodName = Regex.Replace(source.AssemblyName ?? string.Empty, "\\W", "");
- var methodInternal = source.Options.MethodOptions?.Internal;
+ var methodInternal = source.MethodOptions?.Internal;
// generate registration method
var result = ServiceRegistrationWriter.GenerateExtensionClass(
moduleRegistrations,
serviceRegistrations,
- source.Options.AssemblyName,
+ decoratorRegistrations,
+ source.AssemblyName,
methodName,
methodInternal);
@@ -83,91 +239,184 @@ private void ExecuteGeneration(
sourceContext.AddSource("Injectio.g.cs", SourceText.From(result, Encoding.UTF8));
}
- private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken cancellationToken)
+ private static string LoadEmbedded(string name)
{
- return syntaxNode switch
- {
- ClassDeclarationSyntax { AttributeLists.Count: > 0 } declaration =>
- !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword)
- && !declaration.Modifiers.Any(SyntaxKind.StaticKeyword),
-
- RecordDeclarationSyntax { AttributeLists.Count: > 0 } declaration =>
- !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword)
- && !declaration.Modifiers.Any(SyntaxKind.StaticKeyword),
-
- MemberDeclarationSyntax { AttributeLists.Count: > 0 } declaration =>
- !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword),
+ var assembly = typeof(ServiceRegistrationGenerator).Assembly;
+ var resourceName = "Injectio.Generators.Embedded." + name;
+ using var stream = assembly.GetManifestResourceStream(resourceName)
+ ?? throw new InvalidOperationException($"Embedded resource not found: {resourceName}");
+ using var reader = new StreamReader(stream);
+ return reader.ReadToEnd();
+ }
- _ => false,
- };
+ private static bool IsNonAbstractNonStaticType(SyntaxNode node)
+ {
+ return node is TypeDeclarationSyntax declaration
+ and (ClassDeclarationSyntax or RecordDeclarationSyntax)
+ && !declaration.Modifiers.Any(SyntaxKind.AbstractKeyword)
+ && !declaration.Modifiers.Any(SyntaxKind.StaticKeyword);
}
- private static ServiceRegistrationContext? SemanticTransform(GeneratorSyntaxContext context, CancellationToken cancellationToken)
+ private static EquatableArray TransformServiceRegistration(GeneratorAttributeSyntaxContext context, string serviceLifetime)
{
- return context.Node switch
+ if (context.TargetSymbol is not INamedTypeSymbol classSymbol)
+ return EquatableArray.Empty;
+
+ // find matching attributes for this specific lifetime
+ var results = new List();
+ foreach (var attribute in context.Attributes)
{
- ClassDeclarationSyntax => SemanticTransformClass(context),
- RecordDeclarationSyntax => SemanticTransformClass(context),
- MethodDeclarationSyntax => SemanticTransformMethod(context),
- _ => null
- };
+ var registration = CreateServiceRegistration(classSymbol, attribute, serviceLifetime);
+ if (registration is not null)
+ results.Add(registration);
+ }
+
+ return new EquatableArray(results);
}
- private static ServiceRegistrationContext? SemanticTransformMethod(GeneratorSyntaxContext context)
+ private static EquatableArray TransformDecoratorRegistration(GeneratorAttributeSyntaxContext context)
{
- if (context.Node is not MethodDeclarationSyntax methodDeclaration)
- return null;
+ if (context.TargetSymbol is not INamedTypeSymbol classSymbol)
+ return EquatableArray.Empty;
- var methodSymbol = context.SemanticModel.GetDeclaredSymbol(methodDeclaration);
- if (methodSymbol is null)
- return null;
+ var results = new List();
+ foreach (var attribute in context.Attributes)
+ {
+ var decorator = CreateDecoratorRegistration(classSymbol, attribute);
+ if (decorator is not null)
+ results.Add(decorator);
+ }
- // make sure attribute is for registration
- var attributes = methodSymbol.GetAttributes();
- var isKnown = attributes.Any(SymbolHelpers.IsMethodAttribute);
- if (!isKnown)
+ return new EquatableArray(results);
+ }
+
+ private static ModuleRegistration? TransformModuleRegistration(GeneratorAttributeSyntaxContext context)
+ {
+ if (context.TargetSymbol is not IMethodSymbol methodSymbol)
return null;
var (isValid, hasTagCollection) = ValidateMethod(methodSymbol);
if (!isValid)
return null;
- var registration = new ModuleRegistration
- (
+ return new ModuleRegistration(
ClassName: methodSymbol.ContainingType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat),
MethodName: methodSymbol.Name,
IsStatic: methodSymbol.IsStatic,
HasTagCollection: hasTagCollection
);
-
- return new ServiceRegistrationContext(ModuleRegistrations: new[] { registration });
}
- private static ServiceRegistrationContext? SemanticTransformClass(GeneratorSyntaxContext context)
+ private static DecoratorRegistration? CreateDecoratorRegistration(INamedTypeSymbol classSymbol, AttributeData attribute)
{
- if (context.Node is not (TypeDeclarationSyntax declaration and (ClassDeclarationSyntax or RecordDeclarationSyntax)))
- return null;
+ string? serviceType = null;
+ string? implementationType = null;
+ string? serviceKey = null;
+ bool isAnyKey = false;
+ string? factory = null;
+ int order = 0;
+ var tags = new HashSet();
+ bool isOpenGeneric = false;
- var classSymbol = context.SemanticModel.GetDeclaredSymbol(declaration);
- if (classSymbol is null)
- return null;
+ var attributeClass = attribute.AttributeClass;
+ if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length)
+ {
+ for (var index = 0; index < attributeClass.TypeParameters.Length; index++)
+ {
+ var typeParameter = attributeClass.TypeParameters[index];
+ var typeArgument = attributeClass.TypeArguments[index];
- var attributes = classSymbol.GetAttributes();
+ if (typeParameter.Name == "TService" || index == 0)
+ {
+ isOpenGeneric = isOpenGeneric || IsOpenGeneric(typeArgument as INamedTypeSymbol);
+ serviceType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ else if (typeParameter.Name == "TImplementation" || index == 1)
+ {
+ isOpenGeneric = isOpenGeneric || IsOpenGeneric(typeArgument as INamedTypeSymbol);
+ implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ }
+ }
+
+ foreach (var parameter in attribute.NamedArguments)
+ {
+ var name = parameter.Key;
+ var value = parameter.Value.Value;
+
+ if (string.IsNullOrEmpty(name))
+ continue;
- // support multiple register attributes on a class
- var registrations = new List();
+ switch (name)
+ {
+ case "ServiceType":
+ if (value is INamedTypeSymbol serviceTypeSymbol)
+ {
+ isOpenGeneric = isOpenGeneric || IsOpenGeneric(serviceTypeSymbol);
+ serviceType = serviceTypeSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ else if (value != null)
+ {
+ serviceType = value.ToString();
+ }
+ break;
+ case "ImplementationType":
+ if (value is INamedTypeSymbol implSymbol)
+ {
+ isOpenGeneric = isOpenGeneric || IsOpenGeneric(implSymbol);
+ implementationType = implSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ else if (value != null)
+ {
+ implementationType = value.ToString();
+ }
+ break;
+ case "ServiceKey":
+ serviceKey = parameter.Value.ToCSharpString();
+ break;
+ case "AnyKey":
+ if (value is bool anyKey)
+ isAnyKey = anyKey;
+ break;
+ case "Factory":
+ factory = value?.ToString();
+ break;
+ case "Order":
+ if (value is int orderValue)
+ order = orderValue;
+ break;
+ case "Tags":
+ if (value is string tagsText)
+ {
+ foreach (var tag in tagsText.Split(',', ';'))
+ {
+ if (tag.HasValue())
+ tags.Add(tag.Trim());
+ }
+ }
+ break;
+ }
+ }
- foreach (var attribute in attributes)
+ if (implementationType.IsNullOrWhiteSpace())
{
- var registration = CreateServiceRegistration(classSymbol, attribute);
- if (registration is not null)
- registrations.Add(registration);
+ var unboundType = SymbolHelpers.ToUnboundGenericType(classSymbol);
+ isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundType);
+ implementationType = unboundType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
}
- if (registrations.Count == 0)
+ if (serviceType.IsNullOrWhiteSpace())
return null;
- return new ServiceRegistrationContext(ServiceRegistrations: registrations.ToArray());
+ return new DecoratorRegistration(
+ DecoratorType: implementationType!,
+ ServiceType: serviceType!,
+ ServiceKey: serviceKey,
+ IsAnyKey: isAnyKey,
+ Factory: factory,
+ Order: order,
+ Tags: tags.ToArray(),
+ IsOpenGeneric: isOpenGeneric);
}
private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbol methodSymbol)
@@ -199,12 +448,8 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
return (false, false);
}
- private static ServiceRegistration? CreateServiceRegistration(INamedTypeSymbol classSymbol, AttributeData attribute)
+ private static ServiceRegistration? CreateServiceRegistration(INamedTypeSymbol classSymbol, AttributeData attribute, string serviceLifetime)
{
- // check for known attribute
- if (!SymbolHelpers.IsKnownAttribute(attribute, out var serviceLifetime))
- return null;
-
// defaults
var serviceTypes = new HashSet();
string implementationType = null!;
@@ -308,6 +553,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName
or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName;
+
if (includeInterfaces)
{
foreach (var implementedInterface in classSymbol.AllInterfaces)
@@ -328,6 +574,7 @@ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName
or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName;
+
if (includeSelf || serviceTypes.Count == 0)
serviceTypes.Add(implementationType!);
diff --git a/src/Injectio.Generators/ServiceRegistrationWriter.cs b/src/Injectio.Generators/ServiceRegistrationWriter.cs
index a393807..26f9aa6 100644
--- a/src/Injectio.Generators/ServiceRegistrationWriter.cs
+++ b/src/Injectio.Generators/ServiceRegistrationWriter.cs
@@ -1,4 +1,6 @@
using Injectio.Generators.Extensions;
+using Injectio.Generators.Infrastructure;
+using Injectio.Generators.Models;
namespace Injectio.Generators;
@@ -11,6 +13,15 @@ public static string GenerateExtensionClass(
string? assemblyName,
string? methodName,
string? methodInternal)
+ => GenerateExtensionClass(moduleRegistrations, serviceRegistrations, Array.Empty(), assemblyName, methodName, methodInternal);
+
+ public static string GenerateExtensionClass(
+ IReadOnlyList moduleRegistrations,
+ IReadOnlyList serviceRegistrations,
+ IReadOnlyList decoratorRegistrations,
+ string? assemblyName,
+ string? methodName,
+ string? methodInternal)
{
var codeBuilder = new IndentedStringBuilder();
codeBuilder
@@ -66,6 +77,11 @@ public static string GenerateExtensionClass(
WriteRegistration(codeBuilder, serviceRegistration);
}
+ foreach (var decoratorRegistration in decoratorRegistrations)
+ {
+ WriteDecorator(codeBuilder, decoratorRegistration);
+ }
+
codeBuilder
.AppendLine("return serviceCollection;")
.DecrementIndent()
@@ -320,6 +336,165 @@ private static void WriteServiceGeneric(
.AppendLine();
}
+ private static void WriteDecorator(
+ IndentedStringBuilder codeBuilder,
+ DecoratorRegistration decorator)
+ {
+ if (decorator.Tags.Count > 0)
+ {
+ codeBuilder
+ .Append("if (tagSet.Count == 0 || tagSet.Intersect(new[] { ");
+
+ bool wroteTag = false;
+ foreach (var tag in decorator.Tags)
+ {
+ if (wroteTag)
+ codeBuilder.Append(", ");
+
+ codeBuilder
+ .Append("\"")
+ .Append(tag)
+ .Append("\"");
+
+ wroteTag = true;
+ }
+
+ codeBuilder
+ .AppendLine(" }).Any())")
+ .AppendLine("{")
+ .IncrementIndent();
+ }
+
+ var serviceType = decorator.ServiceType;
+ var decoratorType = decorator.DecoratorType;
+ bool hasServiceKey = decorator.ServiceKey.HasValue();
+ bool isKeyed = hasServiceKey || decorator.IsAnyKey;
+
+ // resolve the service key expression passed to the helper
+ string keyExpression;
+ if (decorator.IsAnyKey)
+ keyExpression = "global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey";
+ else if (hasServiceKey)
+ keyExpression = decorator.ServiceKey!;
+ else
+ keyExpression = "null";
+
+ if (decorator.IsOpenGeneric)
+ {
+ codeBuilder
+ .Append("global::Injectio.Extensions.DecorationExtensions.DecorateOpenGeneric(")
+ .AppendLine()
+ .IncrementIndent()
+ .AppendLine("serviceCollection,")
+ .Append("typeof(")
+ .AppendIf("global::", !serviceType.StartsWith("global::"))
+ .Append(serviceType)
+ .AppendLine("),")
+ .Append("typeof(")
+ .AppendIf("global::", !decoratorType.StartsWith("global::"))
+ .Append(decoratorType)
+ .AppendLine(")")
+ .DecrementIndent()
+ .AppendLine(");")
+ .AppendLine();
+ }
+ else if (isKeyed)
+ {
+ codeBuilder
+ .Append("global::Injectio.Extensions.DecorationExtensions.DecorateKeyed<")
+ .AppendIf("global::", !serviceType.StartsWith("global::"))
+ .Append(serviceType)
+ .AppendLine(">(")
+ .IncrementIndent()
+ .AppendLine("serviceCollection,")
+ .Append(keyExpression)
+ .AppendLine(",");
+
+ WriteDecoratorFactory(codeBuilder, decorator, isKeyed: true);
+
+ codeBuilder
+ .AppendLine()
+ .DecrementIndent()
+ .AppendLine(");")
+ .AppendLine();
+ }
+ else
+ {
+ codeBuilder
+ .Append("global::Injectio.Extensions.DecorationExtensions.Decorate<")
+ .AppendIf("global::", !serviceType.StartsWith("global::"))
+ .Append(serviceType)
+ .AppendLine(">(")
+ .IncrementIndent()
+ .AppendLine("serviceCollection,");
+
+ WriteDecoratorFactory(codeBuilder, decorator, isKeyed: false);
+
+ codeBuilder
+ .AppendLine()
+ .DecrementIndent()
+ .AppendLine(");")
+ .AppendLine();
+ }
+
+ if (decorator.Tags.Count > 0)
+ {
+ codeBuilder
+ .DecrementIndent()
+ .AppendLine("}")
+ .AppendLine();
+ }
+ }
+
+ private static void WriteDecoratorFactory(
+ IndentedStringBuilder codeBuilder,
+ DecoratorRegistration decorator,
+ bool isKeyed)
+ {
+ var serviceType = decorator.ServiceType;
+ var decoratorType = decorator.DecoratorType;
+ var qualifiedService = serviceType.StartsWith("global::") ? serviceType : "global::" + serviceType;
+ var qualifiedDecorator = decoratorType.StartsWith("global::") ? decoratorType : "global::" + decoratorType;
+
+ if (decorator.Factory.HasValue())
+ {
+ bool hasNamespace = decorator.Factory!.Contains(".");
+ var factoryTarget = hasNamespace ? decorator.Factory! : qualifiedDecorator + "." + decorator.Factory;
+
+ if (isKeyed)
+ {
+ codeBuilder
+ .Append("static (serviceProvider, serviceKey, inner) => ")
+ .Append(factoryTarget)
+ .Append("(serviceProvider, serviceKey, inner)");
+ }
+ else
+ {
+ codeBuilder
+ .Append("static (serviceProvider, inner) => ")
+ .Append(factoryTarget)
+ .Append("(serviceProvider, inner)");
+ }
+ }
+ else
+ {
+ if (isKeyed)
+ {
+ codeBuilder
+ .Append("static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<")
+ .Append(qualifiedDecorator)
+ .Append(">(serviceProvider, inner)");
+ }
+ else
+ {
+ codeBuilder
+ .Append("static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<")
+ .Append(qualifiedDecorator)
+ .Append(">(serviceProvider, inner)");
+ }
+ }
+ }
+
public static string GetServiceCollectionMethod(string duplicateStrategy)
{
return duplicateStrategy switch
diff --git a/src/Injectio.Generators/SymbolHelpers.cs b/src/Injectio.Generators/SymbolHelpers.cs
index 54cc867..04cfc4d 100644
--- a/src/Injectio.Generators/SymbolHelpers.cs
+++ b/src/Injectio.Generators/SymbolHelpers.cs
@@ -61,6 +61,19 @@ public static bool IsScopedAttribute(AttributeData attribute)
};
}
+ public static bool IsDecoratorAttribute(AttributeData attribute)
+ {
+ return attribute?.AttributeClass is
+ {
+ Name: KnownTypes.DecoratorAttributeShortName or KnownTypes.DecoratorAttributeTypeName,
+ ContainingNamespace:
+ {
+ Name: "Attributes",
+ ContainingNamespace.Name: "Injectio"
+ }
+ };
+ }
+
public static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime)
{
if (IsSingletonAttribute(attribute))
diff --git a/src/Injectio.Shared/Attributes.cs b/src/Injectio.Shared/Attributes.cs
new file mode 100644
index 0000000..52746b0
--- /dev/null
+++ b/src/Injectio.Shared/Attributes.cs
@@ -0,0 +1,382 @@
+//
+#nullable enable
+
+namespace Injectio.Attributes
+{
+ ///
+ /// Service registration duplicate strategies
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ internal enum DuplicateStrategy
+ {
+ ///
+ /// Skips registrations for services that already exists.
+ ///
+ Skip = 0,
+ ///
+ /// Replaces existing service registrations.
+ ///
+ Replace = 1,
+ ///
+ /// Appends a new registration for existing services.
+ ///
+ Append = 2
+ }
+
+ ///
+ /// Register service type strategies
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ internal enum RegistrationStrategy
+ {
+ ///
+ /// Registers each matching concrete type as itself
+ ///
+ Self = 0,
+ ///
+ /// Registers each matching concrete type as all of its implemented interfaces.
+ ///
+ ImplementedInterfaces = 1,
+ ///
+ /// Registers each matching concrete type as all of its implemented interfaces and itself
+ ///
+ SelfWithInterfaces = 2,
+ ///
+ /// Registers each matching concrete type as all of its implemented interfaces and itself.
+ /// For the interfaces a proxy-factory resolves the service from its type-name, so only one instance is created per lifetime
+ ///
+ /// For open-generic registrations, this behaves like
+ SelfWithProxyFactory = 3
+ }
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal abstract class RegisterAttribute : global::System.Attribute
+ {
+ ///
+ /// The of the service
+ ///
+ ///
+ public global::System.Type? ServiceType { get; set; }
+
+ ///
+ /// The that implements the service. If not set, the class the interface is on will be used.
+ ///
+ ///
+ public global::System.Type? ImplementationType { get; set; }
+
+ ///
+ /// Gets or sets the key of the service.
+ ///
+ /// The service key.
+ ///
+ public object? ServiceKey { get; set; }
+
+ ///
+ /// Name of a factory method to create new instances of the service implementation
+ ///
+ ///
+ /// The method signature must be Func{IServiceProvider, object} and must be static
+ ///
+ ///
+ public string? Factory { get; set; }
+
+ ///
+ /// Gets or sets the duplicate.
+ ///
+ ///
+ /// The duplicate.
+ ///
+ public DuplicateStrategy Duplicate { get; set; }
+
+ ///
+ /// Gets or sets the registration.
+ ///
+ ///
+ /// The registration.
+ ///
+ public RegistrationStrategy Registration { get; set; }
+
+ ///
+ /// Gets or sets the comma delimited list of service registration tags.
+ ///
+ ///
+ /// The comma delimited list of service registration tags.
+ ///
+ public string? Tags { get; set; }
+ }
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a scoped service
+ ///
+ /// Register the class as an implementation for IService
+ ///
+ /// [RegisterScoped]
+ /// public class ScopedService : IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterScopedAttribute : RegisterAttribute;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a singleton service
+ ///
+ /// Register the class as an implementation for IService
+ ///
+ /// [RegisterSingleton]
+ /// public class SingletonService : IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterSingletonAttribute : RegisterAttribute;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a transient service
+ ///
+ /// Register the class as an implementation for IService
+ ///
+ /// [RegisterTransient]
+ /// public class TransientService : IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterTransientAttribute : RegisterAttribute;
+
+ /// Attribute to indicate the method should be called to register services
+ /// use the RegisterServices attribute
+ ///
+ /// public class RegistrationModule
+ /// {
+ /// [RegisterServices]
+ /// public static void Register(IServiceCollection services)
+ /// {
+ /// services.TryAddTransient<IModuleService, ModuleService>();
+ /// }
+ /// }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Method)]
+ internal class RegisterServicesAttribute : global::System.Attribute;
+
+ ///
+ /// Attribute to indicate the target class should be registered as a decorator for an existing service.
+ /// The decorator wraps the previously registered service implementation and inherits its .
+ ///
+ /// Decorate IService with a logging wrapper
+ ///
+ /// [RegisterDecorator(ServiceType = typeof(IService))]
+ /// public class LoggingDecorator : IService
+ /// {
+ /// public LoggingDecorator(IService inner) { }
+ /// }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
+ internal class RegisterDecoratorAttribute : global::System.Attribute
+ {
+ ///
+ /// The of the service to decorate.
+ ///
+ public global::System.Type? ServiceType { get; set; }
+
+ ///
+ /// The that implements the decorator. If not set, the class the attribute is on will be used.
+ ///
+ public global::System.Type? ImplementationType { get; set; }
+
+ ///
+ /// Gets or sets the key of the keyed service to decorate.
+ /// Leave unset (and false) to decorate the non-keyed registration.
+ ///
+ public object? ServiceKey { get; set; }
+
+ ///
+ /// When true, the decorator is applied to every keyed registration of ,
+ /// regardless of its key. Equivalent to KeyedService.AnyKey.
+ ///
+ public bool AnyKey { get; set; }
+
+ ///
+ /// Name of a static factory method to construct the decorator.
+ ///
+ ///
+ /// The factory signature must be (IServiceProvider, TService) -> TService for non-keyed services
+ /// or (IServiceProvider, object?, TService) -> TService for keyed services.
+ ///
+ public string? Factory { get; set; }
+
+ ///
+ /// Gets or sets the order in which the decorator is applied. Lower values are applied first (innermost).
+ ///
+ public int Order { get; set; }
+
+ ///
+ /// Gets or sets the comma delimited list of registration tags.
+ ///
+ public string? Tags { get; set; }
+ }
+
+#if NET7_0_OR_GREATER
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a scoped service
+ ///
+ /// The type of the service to add.
+ /// Register the class as an implementation for IService
+ ///
+ /// [RegisterScoped<IService>]
+ /// public class ScopedService : IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterScopedAttribute : RegisterScopedAttribute
+ where TService : class;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a scoped service
+ ///
+ /// The type of the service to add.
+ /// The type of the implementation to use.
+ /// Register the ScopedService class as an implementation for IService
+ ///
+ /// [RegisterScoped<IService, ScopedService>]
+ /// public class ScopedService: IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterScopedAttribute : RegisterScopedAttribute
+ where TService : class
+ where TImplementation : class, TService;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a singleton service
+ ///
+ /// The type of the service to add.
+ /// Register the class as an implementation for IService
+ ///
+ /// [RegisterSingleton<IService>]
+ /// public class SingletonService : IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterSingletonAttribute : RegisterSingletonAttribute
+ where TService : class;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a singleton service
+ ///
+ /// The type of the service to add.
+ /// The type of the implementation to use.
+ /// Register the SingletonService class as an implementation for IService
+ ///
+ /// [RegisterSingleton<IService, SingletonService>]
+ /// public class SingletonService: IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterSingletonAttribute : RegisterSingletonAttribute
+ where TService : class
+ where TImplementation : class, TService;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a transient service
+ ///
+ /// The type of the service to add.
+ /// Register the class as an implementation for IService
+ ///
+ /// [RegisterTransient<IService>]
+ /// public class TransientService : IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ internal class RegisterTransientAttribute : RegisterTransientAttribute
+ where TService : class;
+
+ ///
+ /// Attribute to indicate the target class should be register for dependency injection as a transient service
+ ///
+ /// The type of the service to add.
+ /// The type of the implementation to use.
+ /// Register the TransientService class as an implementation for IService
+ ///
+ /// [RegisterTransient<IService, TransientService>]
+ /// public class TransientService: IService { }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true)]
+ internal class RegisterTransientAttribute : RegisterTransientAttribute
+ where TService : class
+ where TImplementation : class, TService;
+
+ ///
+ /// Attribute to indicate the target class should be registered as a decorator for .
+ ///
+ /// The type of the service to decorate.
+ ///
+ ///
+ /// [RegisterDecorator<IService>]
+ /// public class LoggingDecorator : IService
+ /// {
+ /// public LoggingDecorator(IService inner) { }
+ /// }
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
+ internal class RegisterDecoratorAttribute : RegisterDecoratorAttribute
+ where TService : class;
+
+ ///
+ /// Attribute to indicate the target class should be registered as a decorator for
+ /// using as the decorator implementation.
+ ///
+ /// The type of the service to decorate.
+ /// The type of the decorator implementation.
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")]
+ [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
+ internal class RegisterDecoratorAttribute : RegisterDecoratorAttribute
+ where TService : class
+ where TImplementation : class, TService;
+#endif
+}
diff --git a/src/Injectio.Shared/Extensions.cs b/src/Injectio.Shared/Extensions.cs
new file mode 100644
index 0000000..6554b1d
--- /dev/null
+++ b/src/Injectio.Shared/Extensions.cs
@@ -0,0 +1,451 @@
+//
+#nullable enable
+
+namespace Injectio.Extensions
+{
+ ///
+ /// Provides extension methods for decorating registered services in an .
+ ///
+ ///
+ ///
+ /// The Decorator pattern wraps an existing service with additional behavior without modifying the
+ /// original implementation. Each decorator implements the same interface as the service it wraps, creating
+ /// a chain of responsibility:
+ ///
+ ///
+ /// These methods work by replacing the
+ /// in-place within the service collection. The original factory, instance, or implementation type is captured
+ /// so the inner service can still be resolved. The replacement descriptor preserves the original lifetime.
+ /// Calling a decoration method multiple times stacks decorators, with the last call becoming the outermost layer.
+ ///
+ /// Usage example — decorating with a factory:
+ ///
+ /// // Register the original service
+ /// services.AddScoped<INotificationService, EmailNotificationService>();
+ ///
+ /// // Decorate with logging
+ /// services.Decorate<INotificationService>(
+ /// (sp, inner) => new LoggingNotificationService(inner, sp.GetRequiredService<ILogger>()));
+ ///
+ /// // Decorate with retry (stacks on top of logging)
+ /// services.Decorate<INotificationService>(
+ /// (sp, inner) => new RetryNotificationService(inner));
+ ///
+ /// // Resolution chain: Retry → Logging → Email
+ ///
+ /// Usage example — decorating open generics:
+ ///
+ /// // Register closed-generic implementations
+ /// services.AddScoped<IRepository<Order>, OrderRepository>();
+ /// services.AddScoped<IRepository<Customer>, CustomerRepository>();
+ ///
+ /// // Decorate all IRepository<T> with caching
+ /// services.DecorateOpenGeneric(
+ /// typeof(IRepository<>), typeof(CachingRepository<>));
+ ///
+ ///
+ [global::Microsoft.CodeAnalysis.Embedded]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ internal static class DecorationExtensions
+ {
+ ///
+ /// Decorates all registrations of by wrapping each resolved instance with .
+ /// Keyed service registrations are skipped.
+ ///
+ /// The service type to decorate.
+ /// The service collection containing the registrations to decorate.
+ /// A factory that receives the service provider and the original instance, and returns the decorated instance.
+ /// The same instance for chaining.
+ ///
+ ///
+ /// Decoration replaces each existing for
+ /// in-place. The original descriptor's factory, instance, or implementation type is
+ /// captured and used to resolve the inner service, which is then passed to .
+ /// The replacement descriptor preserves the original registration's lifetime.
+ ///
+ ///
+ /// Multiple calls stack: each call wraps whatever factory is currently registered, enabling layered decoration.
+ /// Only non-keyed registrations are affected; use DecorateKeyed for keyed services.
+ ///
+ ///
+ public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate(
+ this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services,
+ global::System.Func decoratorFactory)
+ where TService : class
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+
+ if (decoratorFactory is null)
+ throw new global::System.ArgumentNullException(nameof(decoratorFactory));
+
+ var serviceType = typeof(TService);
+ int count = services.Count;
+
+ for (int i = 0; i < count; i++)
+ {
+ var descriptor = services[i];
+
+ // Only rewrite matching non-keyed descriptors for the requested service type.
+ if (descriptor.ServiceType != serviceType)
+ continue;
+
+ if (IsKeyedDescriptor(descriptor))
+ continue;
+
+ var lifetime = descriptor.Lifetime;
+
+ if (descriptor.ImplementationInstance is object instance)
+ {
+ // Wrap pre-built singleton/instance registrations without re-creating the inner service.
+ services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(
+ serviceType: serviceType,
+ factory: sp => decoratorFactory(sp, (TService)instance),
+ lifetime: lifetime
+ );
+
+ continue;
+ }
+
+ if (descriptor.ImplementationFactory is global::System.Func factory)
+ {
+ // Preserve factory semantics by resolving inner through the original factory.
+ services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(
+ serviceType: serviceType,
+ factory: sp => decoratorFactory(sp, (TService)factory(sp)),
+ lifetime: lifetime
+ );
+
+ continue;
+ }
+
+ var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType;
+
+ // Type registrations are activated first, then passed as the inner instance to the decorator.
+ services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(
+ serviceType: serviceType,
+ factory: sp => decoratorFactory(sp, (TService)global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType))!,
+ lifetime: lifetime
+ );
+ }
+
+ return services;
+ }
+
+ ///
+ /// Decorates all registrations of by wrapping each resolved instance
+ /// with an automatically constructed instance of .
+ /// Keyed service registrations are skipped.
+ ///
+ /// The service type to decorate.
+ /// The decorator type. Its constructor must accept a parameter for the inner service; additional parameters are resolved from the service provider.
+ /// The service collection containing the registrations to decorate.
+ /// The same instance for chaining.
+ ///
+ ///
+ /// This overload uses
+ /// to construct the decorator, passing the inner service instance as an explicit argument. Any remaining
+ /// constructor parameters are resolved from the .
+ ///
+ ///
+ public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate(
+ this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)
+ where TService : class
+ where TDecorator : class, TService
+ {
+ // Delegate to the factory-based overload so stacking and lifetime behavior stay centralized.
+ return services.Decorate(
+ (sp, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, inner));
+ }
+
+#if NET8_0_OR_GREATER
+ ///
+ /// Decorates all keyed registrations of matching
+ /// by wrapping each resolved instance with .
+ /// Pass to match all keys.
+ ///
+ /// The service type to decorate.
+ /// The service collection containing the registrations to decorate.
+ /// The key to match, or for all keys.
+ /// A factory that receives the service provider, the key, and the original instance, and returns the decorated instance.
+ /// The same instance for chaining.
+ ///
+ ///
+ /// This method iterates through all service descriptors for and replaces each
+ /// keyed registration whose key matches . The original keyed factory, instance,
+ /// or implementation type is captured and used to resolve the inner service, which is then passed to
+ /// along with the service provider and key. The replacement descriptor
+ /// preserves the original registration's lifetime and key.
+ ///
+ ///
+ /// Non-keyed registrations are skipped; use Decorate for those. Multiple calls stack, enabling
+ /// layered decoration of keyed services.
+ ///
+ ///
+ public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed(
+ this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services,
+ object? serviceKey,
+ global::System.Func decoratorFactory)
+ where TService : class
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+
+ if (decoratorFactory is null)
+ throw new global::System.ArgumentNullException(nameof(decoratorFactory));
+
+ var serviceType = typeof(TService);
+ bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey);
+ int count = services.Count;
+
+ for (int i = 0; i < count; i++)
+ {
+ var descriptor = services[i];
+
+ // Only rewrite matching keyed descriptors for the requested service type/key.
+ if (descriptor.ServiceType != serviceType)
+ continue;
+
+ if (!descriptor.IsKeyedService)
+ continue;
+
+ if (!anyKey && !global::System.Collections.Generic.EqualityComparer