From c0db07b3f3f093b5f4cd27e1f5e8aa54adad049d Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Wed, 30 Mar 2022 18:23:38 -0700 Subject: [PATCH] Minor refactoring in MarshallingAttributeInfoParser (#67325) --- .../ManualTypeMarshallingHelper.cs | 23 ++ .../MarshallingAttributeInfo.cs | 200 ++++++++---------- 2 files changed, 116 insertions(+), 107 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs index f636f96d9fa..24887b880ed 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs @@ -75,6 +75,29 @@ public static (bool hasAttribute, ITypeSymbol? managedType, CustomTypeMarshaller return (true, managedType, new CustomTypeMarshallerData(kind, direction, features, bufferSize)); } + /// + /// Get the supported for a marshaller type + /// + /// The marshaller type. + /// The mananged type that would be marshalled. + /// Supported + public static CustomTypeMarshallerPinning GetMarshallerPinningFeatures(ITypeSymbol marshallerType, ITypeSymbol? managedType) + { + CustomTypeMarshallerPinning pinning = CustomTypeMarshallerPinning.None; + + if (FindGetPinnableReference(marshallerType) is not null) + { + pinning |= CustomTypeMarshallerPinning.NativeType; + } + + if (managedType is not null && FindGetPinnableReference(managedType) is not null) + { + pinning |= CustomTypeMarshallerPinning.ManagedType; + } + + return pinning; + } + /// /// Resolve a non- to the correct managed type if is generic and is using any placeholder types. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index 3e3e79b6c5a..a26bb93e3e3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -536,62 +536,29 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS } } - if (!isArrayType) + if (isArrayType) { - return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); - } - - if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - return NoMarshallingInfo.Instance; - } - - MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; - if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize) - { - elementMarshallingInfo = new MarshalAsInfo(elementUnmanagedType, _defaultInfo.CharEncoding); - } - else - { - maxIndirectionDepthUsed = 1; - elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsed); - } - - INamedTypeSymbol? arrayMarshaller; + if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + { + _diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); + return NoMarshallingInfo.Instance; + } - if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); - } - else - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); - } + MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; + if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize) + { + elementMarshallingInfo = new MarshalAsInfo(elementUnmanagedType, _defaultInfo.CharEncoding); + } + else + { + maxIndirectionDepthUsed = 1; + elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsed); + } - if (arrayMarshaller is null) - { - // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. - return new MissingSupportCollectionMarshallingInfo(arraySizeInfo, elementMarshallingInfo); + return CreateArrayMarshallingInfo(elementType, arraySizeInfo, elementMarshallingInfo); } - var (_, _, customTypeMarshallerData) = ManualTypeMarshallingHelper.GetMarshallerShapeInfo(arrayMarshaller); - - Debug.Assert(customTypeMarshallerData is not null); - - ITypeSymbol? nativeValueType = ManualTypeMarshallingHelper.FindToNativeValueMethod(arrayMarshaller)?.ReturnType; - - return new NativeLinearCollectionMarshallingInfo( - NativeMarshallingType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), - NativeValueType: nativeValueType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeValueType) : null, - Direction: customTypeMarshallerData.Value.Direction, - MarshallingFeatures: customTypeMarshallerData.Value.Features, - PinningFeatures: CustomTypeMarshallerPinning.NativeType, - UseDefaultMarshalling: true, - BufferSize: customTypeMarshallerData.Value.BufferSize, - ElementCountInfo: arraySizeInfo, - ElementType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType), - ElementMarshallingInfo: elementMarshallingInfo); + return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); } private MarshallingInfo CreateNativeMarshallingInfo( @@ -605,8 +572,6 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS ImmutableHashSet inspectedElements, ref int maxIndirectionDepthUsed) { - INamedTypeSymbol readOnlySpanOfT = _compilation.GetTypeByMetadataName(TypeNames.System_ReadOnlySpan_Metadata)!; - if (nativeType.IsUnboundGenericType) { if (isMarshalUsingAttribute) @@ -639,31 +604,21 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS return NoMarshallingInfo.Instance; } - CustomTypeMarshallerPinning pinning = CustomTypeMarshallerPinning.None; - - if (!isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) - { - pinning |= CustomTypeMarshallerPinning.ManagedType; - } - - if (ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null) - { - pinning |= CustomTypeMarshallerPinning.NativeType; - } - - IMethodSymbol? toNativeValueMethod = ManualTypeMarshallingHelper.FindToNativeValueMethod(nativeType); - if (customTypeMarshallerData.Value.Kind == CustomTypeMarshallerKind.LinearCollection) { + INamedTypeSymbol readOnlySpanOfT = _compilation.GetTypeByMetadataName(TypeNames.System_ReadOnlySpan_Metadata)!; if (!ManualTypeMarshallingHelper.TryGetElementTypeFromLinearCollectionMarshaller(nativeType, readOnlySpanOfT, out ITypeSymbol elementType)) { _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CollectionNativeTypeMustHaveRequiredShapeMessage), nativeType.ToDisplayString()); return NoMarshallingInfo.Instance; } + CustomTypeMarshallerPinning pinning = ManualTypeMarshallingHelper.GetMarshallerPinningFeatures(nativeType, isMarshalUsingAttribute ? null : type); + IMethodSymbol? toNativeValueMethod = ManualTypeMarshallingHelper.FindToNativeValueMethod(nativeType); + ManagedTypeInfo? nativeValueType = toNativeValueMethod is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(toNativeValueMethod.ReturnType) : null; return new NativeLinearCollectionMarshallingInfo( ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType), - toNativeValueMethod is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(toNativeValueMethod.ReturnType) : null, + nativeValueType, customTypeMarshallerData.Value.Direction, customTypeMarshallerData.Value.Features, pinning, @@ -674,14 +629,31 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsed)); } + return CreateNativeMarshallingInfoForValue( + type, + nativeType, + attrData, + customTypeMarshallerData.Value, + allowPinningManagedType: !isMarshalUsingAttribute, + useDefaultMarshalling: !isMarshalUsingAttribute); + } + + private MarshallingInfo CreateNativeMarshallingInfoForValue( + ITypeSymbol type, + INamedTypeSymbol nativeType, + AttributeData attrData, + CustomTypeMarshallerData customTypeMarshallerData, + bool allowPinningManagedType, + bool useDefaultMarshalling) + { ManagedTypeInfo? bufferElementTypeInfo = null; - if (customTypeMarshallerData.Value.Features.HasFlag(CustomTypeMarshallerFeatures.CallerAllocatedBuffer)) + if (customTypeMarshallerData.Features.HasFlag(CustomTypeMarshallerFeatures.CallerAllocatedBuffer)) { ITypeSymbol? bufferElementType = null; INamedTypeSymbol spanOfT = _compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!; foreach (IMethodSymbol ctor in nativeType.Constructors) { - if (ManualTypeMarshallingHelper.IsCallerAllocatedSpanConstructor(ctor, type, spanOfT, customTypeMarshallerData.Value.Kind, out bufferElementType)) + if (ManualTypeMarshallingHelper.IsCallerAllocatedSpanConstructor(ctor, type, spanOfT, customTypeMarshallerData.Kind, out bufferElementType)) break; } @@ -694,15 +666,20 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS bufferElementTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(bufferElementType); } + CustomTypeMarshallerPinning pinning = ManualTypeMarshallingHelper.GetMarshallerPinningFeatures(nativeType, allowPinningManagedType ? type : null); + + IMethodSymbol? toNativeValueMethod = ManualTypeMarshallingHelper.FindToNativeValueMethod(nativeType); + ManagedTypeInfo? nativeValueType = toNativeValueMethod is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(toNativeValueMethod.ReturnType) : null; + return new NativeMarshallingAttributeInfo( ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType), - toNativeValueMethod is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(toNativeValueMethod.ReturnType) : null, - customTypeMarshallerData.Value.Direction, - customTypeMarshallerData.Value.Features, + nativeValueType, + customTypeMarshallerData.Direction, + customTypeMarshallerData.Features, pinning, - UseDefaultMarshalling: !isMarshalUsingAttribute, + useDefaultMarshalling, bufferElementTypeInfo, - customTypeMarshallerData.Value.BufferSize); + customTypeMarshallerData.BufferSize); } private bool TryCreateTypeBasedMarshallingInfo( @@ -743,38 +720,8 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) { - INamedTypeSymbol? arrayMarshaller; - - if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); - } - else - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); - } - - if (arrayMarshaller is null) - { - // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. - marshallingInfo = new MissingSupportCollectionMarshallingInfo(parsedCountInfo, GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsed)); - return true; - } - - var (_, _, customTypeMarshallerData) = ManualTypeMarshallingHelper.GetMarshallerShapeInfo(arrayMarshaller); - ITypeSymbol? valuePropertyType = ManualTypeMarshallingHelper.FindToNativeValueMethod(arrayMarshaller)?.ReturnType; - - marshallingInfo = new NativeLinearCollectionMarshallingInfo( - NativeMarshallingType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), - NativeValueType: valuePropertyType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valuePropertyType) : null, - Direction: CustomTypeMarshallerDirection.Ref, - MarshallingFeatures: CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.CallerAllocatedBuffer, - PinningFeatures: CustomTypeMarshallerPinning.NativeType, - UseDefaultMarshalling: true, - customTypeMarshallerData.Value.BufferSize, - ElementCountInfo: parsedCountInfo, - ElementType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType), - ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsed)); + MarshallingInfo elementMarshallingInfo = GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsed); + marshallingInfo = CreateArrayMarshallingInfo(elementType, parsedCountInfo, elementMarshallingInfo); return true; } @@ -820,6 +767,45 @@ private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashS return false; } + private MarshallingInfo CreateArrayMarshallingInfo( + ITypeSymbol elementType, + CountInfo countInfo, + MarshallingInfo elementMarshallingInfo) + { + INamedTypeSymbol? arrayMarshaller; + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); + } + else + { + arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + } + + if (arrayMarshaller is null) + { + // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. + return new MissingSupportCollectionMarshallingInfo(countInfo, elementMarshallingInfo); + } + + var (_, _, customTypeMarshallerData) = ManualTypeMarshallingHelper.GetMarshallerShapeInfo(arrayMarshaller); + Debug.Assert(customTypeMarshallerData is not null); + + ITypeSymbol? nativeValueType = ManualTypeMarshallingHelper.FindToNativeValueMethod(arrayMarshaller)?.ReturnType; + + return new NativeLinearCollectionMarshallingInfo( + NativeMarshallingType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), + NativeValueType: nativeValueType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeValueType) : null, + Direction: customTypeMarshallerData.Value.Direction, + MarshallingFeatures: customTypeMarshallerData.Value.Features, + PinningFeatures: CustomTypeMarshallerPinning.NativeType, + UseDefaultMarshalling: true, + customTypeMarshallerData.Value.BufferSize, + ElementCountInfo: countInfo, + ElementType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType), + ElementMarshallingInfo: elementMarshallingInfo); + } + private MarshallingInfo GetBlittableMarshallingInfo(ITypeSymbol type) { if (type.TypeKind is TypeKind.Enum or TypeKind.Pointer or TypeKind.FunctionPointer -- GitLab