未验证 提交 c7252a39 编写于 作者: G github-actions[bot] 提交者: GitHub

[release/8.0] Fix support of FromKeyedServicesAttribute in...

[release/8.0] Fix support of FromKeyedServicesAttribute in ActivatorUtilities.CreateFactory (#92144)

* Fix support of FromKeyedServicesAttribute in ActivatorUtilities.CreateFactory

* Addressing comment and adding a test

---------
Co-authored-by: NBenjamin Petit <bpetit@microsoft.com>
上级 7b324065
......@@ -36,7 +36,7 @@ public static class ActivatorUtilities
#endif
private static readonly MethodInfo GetServiceInfo =
GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?>>((sp, t, r, c) => GetService(sp, t, r, c));
GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?, object?>>((sp, t, r, c, k) => GetService(sp, t, r, c, k));
/// <summary>
/// Instantiate a type with constructor arguments provided directly and/or from an <see cref="IServiceProvider"/>.
......@@ -324,9 +324,9 @@ private static MethodInfo GetMethodInfo<T>(Expression<T> expr)
return mc.Method;
}
private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue)
private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue, object? key)
{
object? service = sp.GetService(type);
object? service = key == null ? sp.GetService(type) : GetKeyedService(sp, type, key);
if (service is null && !hasDefaultValue)
{
ThrowHelperUnableToResolveService(type, requiredBy);
......@@ -361,10 +361,12 @@ private static void ThrowHelperUnableToResolveService(Type type, Type requiredBy
}
else
{
var keyAttribute = (FromKeyedServicesAttribute?) Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
var parameterTypeExpression = new Expression[] { serviceProvider,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(constructor.DeclaringType, typeof(Type)),
Expression.Constant(hasDefaultValue) };
Expression.Constant(hasDefaultValue),
Expression.Constant(keyAttribute?.Key) };
constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression);
}
......@@ -435,10 +437,10 @@ private static void ThrowHelperArgumentNullExceptionServiceProvider()
if (matchedArgCount == 0)
{
// All injected; use a fast path.
Type[] types = GetParameterTypes();
FactoryParameterContext[] parameters = GetFactoryParameterContext();
return useFixedValues ?
(serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) :
(serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider);
(serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, parameters, declaringType, serviceProvider) :
(serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, parameters, declaringType, serviceProvider);
}
if (matchedArgCount == constructorParameters.Length)
......@@ -456,16 +458,6 @@ ObjectFactory InvokeCanonical()
(serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) :
(serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments);
}
Type[] GetParameterTypes()
{
Type[] types = new Type[constructorParameters.Length];
for (int i = 0; i < constructorParameters.Length; i++)
{
types[i] = constructorParameters[i].ParameterType;
}
return types;
}
#else
ParameterInfo[] constructorParameters = constructor.GetParameters();
if (constructorParameters.Length == 0)
......@@ -484,8 +476,15 @@ FactoryParameterContext[] GetFactoryParameterContext()
for (int i = 0; i < constructorParameters.Length; i++)
{
ParameterInfo constructorParameter = constructorParameters[i];
FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?)
Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue);
parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1);
parameters[i] = new FactoryParameterContext(
constructorParameter.ParameterType,
hasDefaultValue,
defaultValue,
parameterMap[i] ?? -1,
attr?.Key);
}
return parameters;
......@@ -495,18 +494,20 @@ FactoryParameterContext[] GetFactoryParameterContext()
private readonly struct FactoryParameterContext
{
public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex)
public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex, object? serviceKey)
{
ParameterType = parameterType;
HasDefaultValue = hasDefaultValue;
DefaultValue = defaultValue;
ArgumentIndex = argumentIndex;
ServiceKey = serviceKey;
}
public Type ParameterType { get; }
public bool HasDefaultValue { get; }
public object? DefaultValue { get; }
public int ArgumentIndex { get; }
public object? ServiceKey { get; }
}
private static void FindApplicableConstructor(
......@@ -825,39 +826,39 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
#if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters.
private static object ReflectionFactoryServiceOnlyFixed(
ConstructorInvoker invoker,
Type[] parameterTypes,
FactoryParameterContext[] parameters,
Type declaringType,
IServiceProvider serviceProvider)
{
Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold);
Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold);
Debug.Assert(FixedArgumentThreshold == 4);
if (serviceProvider is null)
ThrowHelperArgumentNullExceptionServiceProvider();
switch (parameterTypes.Length)
switch (parameters.Length)
{
case 1:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey));
case 2:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey));
case 3:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false),
GetService(serviceProvider, parameterTypes[2], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey));
case 4:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false),
GetService(serviceProvider, parameterTypes[2], declaringType, false),
GetService(serviceProvider, parameterTypes[3], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey),
GetService(serviceProvider, parameters[3].ParameterType, declaringType, false, parameters[3].ServiceKey));
}
return null!;
......@@ -865,17 +866,17 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
private static object ReflectionFactoryServiceOnlySpan(
ConstructorInvoker invoker,
Type[] parameterTypes,
FactoryParameterContext[] parameters,
Type declaringType,
IServiceProvider serviceProvider)
{
if (serviceProvider is null)
ThrowHelperArgumentNullExceptionServiceProvider();
object?[] arguments = new object?[parameterTypes.Length];
for (int i = 0; i < parameterTypes.Length; i++)
object?[] arguments = new object?[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false);
arguments[i] = GetService(serviceProvider, parameters[i].ParameterType, declaringType, false, parameters[i].ServiceKey);
}
return invoker.Invoke(arguments.AsSpan());
......@@ -907,7 +908,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue);
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue);
case 2:
{
ref FactoryParameterContext parameter2 = ref parameters[1];
......@@ -920,7 +922,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
......@@ -928,7 +931,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue);
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue);
}
case 3:
{
......@@ -943,7 +947,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
......@@ -951,7 +956,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue,
((parameter3.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter3.ArgumentIndex]
......@@ -959,7 +965,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter3.ParameterType,
declaringType,
parameter3.HasDefaultValue)) ?? parameter3.DefaultValue);
parameter3.HasDefaultValue,
parameter3.ServiceKey)) ?? parameter3.DefaultValue);
}
case 4:
{
......@@ -975,7 +982,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
......@@ -983,7 +991,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue,
((parameter3.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter3.ArgumentIndex]
......@@ -991,7 +1000,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter3.ParameterType,
declaringType,
parameter3.HasDefaultValue)) ?? parameter3.DefaultValue,
parameter3.HasDefaultValue,
parameter3.ServiceKey)) ?? parameter3.DefaultValue,
((parameter4.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter4.ArgumentIndex]
......@@ -999,7 +1009,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter4.ParameterType,
declaringType,
parameter4.HasDefaultValue)) ?? parameter4.DefaultValue);
parameter4.HasDefaultValue,
parameter4.ServiceKey)) ?? parameter4.DefaultValue);
}
}
......@@ -1028,7 +1039,8 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
serviceProvider,
parameter.ParameterType,
declaringType,
parameter.HasDefaultValue)) ?? parameter.DefaultValue;
parameter.HasDefaultValue,
parameter.ServiceKey)) ?? parameter.DefaultValue;
}
return invoker.Invoke(constructorArguments.AsSpan());
......@@ -1078,7 +1090,8 @@ private static void ThrowHelperNullReferenceException()
serviceProvider,
parameter.ParameterType,
declaringType,
parameter.HasDefaultValue)) ?? parameter.DefaultValue;
parameter.HasDefaultValue,
parameter.ServiceKey)) ?? parameter.DefaultValue;
}
return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null);
......@@ -1099,5 +1112,17 @@ public static void ClearCache(Type[]? _)
}
}
#endif
private static object? GetKeyedService(IServiceProvider provider, Type type, object? serviceKey)
{
ThrowHelper.ThrowIfNull(provider);
if (provider is IKeyedServiceProvider keyedServiceProvider)
{
return keyedServiceProvider.GetKeyedService(type, serviceKey);
}
throw new InvalidOperationException(SR.KeyedServicesNotSupported);
}
}
}
......@@ -476,5 +476,41 @@ public ServiceProviderAccessor(IServiceProvider serviceProvider)
public IServiceProvider ServiceProvider { get; }
}
[Fact]
public void SimpleServiceKeyedResolution()
{
// Arrange
var services = new ServiceCollection();
services.AddKeyedTransient<ISimpleService, SimpleService>("simple");
services.AddKeyedTransient<ISimpleService, AnotherSimpleService>("another");
services.AddTransient<SimpleParentWithDynamicKeyedService>();
var provider = CreateServiceProvider(services);
var sut = provider.GetService<SimpleParentWithDynamicKeyedService>();
// Act
var result = sut!.GetService("simple");
// Assert
Assert.True(result.GetType() == typeof(SimpleService));
}
public class SimpleParentWithDynamicKeyedService
{
private readonly IServiceProvider _serviceProvider;
public SimpleParentWithDynamicKeyedService(IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider;
}
public ISimpleService GetService(string name) => _serviceProvider.GetKeyedService<ISimpleService>(name)!;
}
public interface ISimpleService { }
public class SimpleService : ISimpleService { }
public class AnotherSimpleService : ISimpleService { }
}
}
......@@ -240,6 +240,100 @@ public void CreateFactory_CreatesFactoryMethod_5Types_5Injected()
Assert.NotNull(item.Z);
}
[ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
[InlineData(true)]
#if NETCOREAPP
[InlineData(false)]
#endif
public void CreateFactory_CreatesFactoryMethod_KeyedParams(bool useDynamicCode)
{
var options = new RemoteInvokeOptions();
if (!useDynamicCode)
{
DisableDynamicCode(options);
}
using var remoteHandle = RemoteExecutor.Invoke(static () =>
{
var factory = ActivatorUtilities.CreateFactory<ClassWithAKeyedBKeyedC>(Type.EmptyTypes);
var services = new ServiceCollection();
services.AddSingleton(new A());
services.AddKeyedSingleton("b", new B());
services.AddKeyedSingleton("c", new C());
using var provider = services.BuildServiceProvider();
ClassWithAKeyedBKeyedC item = factory(provider, null);
Assert.IsType<ObjectFactory<ClassWithAKeyedBKeyedC>>(factory);
Assert.NotNull(item.A);
Assert.NotNull(item.B);
Assert.NotNull(item.C);
}, options);
}
[ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
[InlineData(true)]
#if NETCOREAPP
[InlineData(false)]
#endif
public void CreateFactory_CreatesFactoryMethod_KeyedParams_5Types(bool useDynamicCode)
{
var options = new RemoteInvokeOptions();
if (!useDynamicCode)
{
DisableDynamicCode(options);
}
using var remoteHandle = RemoteExecutor.Invoke(static () =>
{
var factory = ActivatorUtilities.CreateFactory<ClassWithAKeyedBKeyedCSZ>(Type.EmptyTypes);
var services = new ServiceCollection();
services.AddSingleton(new A());
services.AddKeyedSingleton("b", new B());
services.AddKeyedSingleton("c", new C());
services.AddSingleton(new S());
services.AddSingleton(new Z());
using var provider = services.BuildServiceProvider();
ClassWithAKeyedBKeyedCSZ item = factory(provider, null);
Assert.IsType<ObjectFactory<ClassWithAKeyedBKeyedCSZ>>(factory);
Assert.NotNull(item.A);
Assert.NotNull(item.B);
Assert.NotNull(item.C);
}, options);
}
[ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
[InlineData(true)]
#if NETCOREAPP
[InlineData(false)]
#endif
public void CreateFactory_CreatesFactoryMethod_KeyedParams_1Injected(bool useDynamicCode)
{
var options = new RemoteInvokeOptions();
if (!useDynamicCode)
{
DisableDynamicCode(options);
}
using var remoteHandle = RemoteExecutor.Invoke(static () =>
{
var factory = ActivatorUtilities.CreateFactory<ClassWithAKeyedBKeyedC>(new Type[] { typeof(A) });
var services = new ServiceCollection();
services.AddKeyedSingleton("b", new B());
services.AddKeyedSingleton("c", new C());
using var provider = services.BuildServiceProvider();
ClassWithAKeyedBKeyedC item = factory(provider, new object?[] { new A() });
Assert.IsType<ObjectFactory<ClassWithAKeyedBKeyedC>>(factory);
Assert.NotNull(item.A);
Assert.NotNull(item.B);
Assert.NotNull(item.C);
}, options);
}
[ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
[InlineData(true)]
#if NETCOREAPP
......@@ -527,6 +621,13 @@ internal class C { }
internal class S { }
internal class Z { }
internal class ClassWithAKeyedBKeyedC : ClassWithABC
{
public ClassWithAKeyedBKeyedC(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c)
: base(a, b, c)
{ }
}
internal class ClassWithABCS : ClassWithABC
{
public S S { get; }
......@@ -540,6 +641,13 @@ internal class ClassWithABCSZ : ClassWithABCS
public ClassWithABCSZ(A a, B b, C c, S s, Z z) : base(a, b, c, s) { Z = z; }
}
internal class ClassWithAKeyedBKeyedCSZ : ClassWithABCSZ
{
public ClassWithAKeyedBKeyedCSZ(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c, S s, Z z)
: base(a, b, c, s, z)
{ }
}
internal class ClassWithABC_FirstConstructorWithAttribute : ClassWithABC
{
[ActivatorUtilitiesConstructor]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册