From bc9c37105c87b1f763692a33c0304e4e81fbbb92 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 29 Aug 2022 17:52:21 -0700 Subject: [PATCH] Refactor marshalling info parser to split overall logic from ordering (#72687) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor marshalling info parser to split overall logic from ordering Split our marshalling attribute parser into separate classes for each attribute and category and make the parser itself only handle executing the various different stages of marshalling info parsing. All logic for parsing a given attribute is handled now by the separate classes. This PR allows the JS marshaller to reuse the core of our parsing and opt-in to more of the logic as they see fit. * Add doc comments for all of the new APIs * Fix failing test * Make the JS known managed type info hang off the JS marshalling info instead of inheriting from ManagedTypeInfo. ManagedTypeInfo is meant to represent just enough info from an ITypeSymbol that we can accurately generate code based on any language/typesystem rules. It is not meant for storing generator-specific marshalling info. * Create JSTypeInfo based on symbols, not type name string parsing. * Apply suggestions from code review Co-authored-by: Marek Fišera * PR feedback. * Add comments Co-authored-by: Marek Fišera --- .../FallbackJSMarshallingInfoProvider.cs | 22 + .../JSExportCodeGenerator.cs | 4 +- .../JSImportGenerator/JSGeneratorFactory.cs | 31 +- .../JSImportCodeGenerator.cs | 4 +- .../JSImportGenerator/JSImportStubContext.cs | 22 +- .../JSImportGenerator/JSManagedTypeInfo.cs | 277 +++---- .../JSMarshallAsAttributeInfoParser.cs | 67 ++ .../JSMarshallingAttributeInfoParser.cs | 86 -- .../JSImportGenerator/JSMarshallingInfo.cs | 16 +- .../Marshaling/FuncJSGenerator.cs | 4 +- .../Marshaling/TaskJSGenerator.cs | 14 +- .../ConvertToLibraryImportAnalyzer.cs | 4 +- .../LibraryImportGenerator.cs | 2 +- .../ArrayMarshallingInfoProvider.cs | 87 +++ .../BlittableTypeMarshallingInfoProvider.cs | 45 ++ .../BooleanMarshallingInfoProvider.cs | 28 + .../CharMarshallingInfoProvider.cs | 32 + .../CustomMarshallingInfoHelper.cs | 111 +++ .../MarshalAsAttributeParser.cs | 158 ++++ .../MarshalUsingAttributeParser.cs | 114 +++ .../Marshalling/MarshallerHelpers.cs | 14 +- .../MarshallingAttributeInfo.cs | 735 +----------------- .../MarshallingInfoParser.cs | 367 +++++++++ .../MethodSignatureElementInfoProvider.cs | 78 ++ .../NativeMarshallingAttributeParser.cs | 49 ++ .../SafeHandleMarshallingInfoProvider.cs | 65 ++ .../SignatureContext.cs | 34 +- .../StringMarshallingInfoProvider.cs | 96 +++ .../UseSiteAttributeProvider.cs | 101 +++ .../CompileFails.cs | 4 +- 30 files changed, 1656 insertions(+), 1015 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs delete mode 100644 src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs new file mode 100644 index 00000000000..58ed34812e7 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop.JavaScript +{ + /// + /// Always returns a JSMissingMarshallingInfo. + /// + internal sealed class FallbackJSMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => true; + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + return new JSMissingMarshallingInfo(JSTypeInfo.CreateJSTypeInfoForTypeSymbol(type)); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs index 1318358bea4..0dba53f9bf0 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs @@ -40,9 +40,9 @@ internal sealed class JSExportCodeGenerator : JSCodeGenerator } // validate task + span mix - if (_marshallers.ManagedReturnMarshaller.TypeInfo.ManagedType is JSTaskTypeInfo) + if (_marshallers.ManagedReturnMarshaller.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSTaskTypeInfo)) { - BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.ManagedType is JSSpanTypeInfo); + BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSSpanTypeInfo)); if (spanArg != default) { marshallingNotSupportedCallback(spanArg.TypeInfo, new MarshallingNotSupportedException(spanArg.TypeInfo, _context) diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs index f908872318b..7c000011bf0 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs @@ -7,6 +7,7 @@ using Microsoft.CodeAnalysis; using System.Collections.Generic; using System.Diagnostics; +using System.Data; namespace Microsoft.Interop.JavaScript { @@ -34,51 +35,51 @@ Exception fail(string failReason) } bool isToJs = info.ManagedIndex != TypePositionInfo.ReturnIndex ^ context is JSExportCodeContext; - switch (info) + switch (jsMarshalingInfo) { // invalid - case { ManagedType: JSInvalidTypeInfo }: + case { TypeInfo: JSInvalidTypeInfo }: throw new MarshallingNotSupportedException(info, context); // void - case { ManagedType: SpecialTypeInfo sd } when sd.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.Discard: - case { ManagedType: SpecialTypeInfo sv } when sv.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.Void: - case { ManagedType: SpecialTypeInfo sn } when sn.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.None: - case { ManagedType: SpecialTypeInfo sm } when sm.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.Missing: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.Discard }: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.Void }: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.None }: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.Missing }: return new VoidGenerator(jsMarshalingInfo.JSType == JSTypeFlags.Void ? MarshalerType.Void : MarshalerType.Discard); // discard no void - case { } when jsMarshalingInfo.JSType == JSTypeFlags.Discard: + case { JSType: JSTypeFlags.Discard }: throw fail(SR.DiscardOnlyVoid); // primitive - case { ManagedType: JSSimpleTypeInfo simple }: + case { TypeInfo: JSSimpleTypeInfo simple }: return Create(info, isToJs, simple.KnownType, Array.Empty(), jsMarshalingInfo.JSType, Array.Empty(), fail); // nullable - case { ManagedType: JSNullableTypeInfo nullable }: + case { TypeInfo: JSNullableTypeInfo nullable }: return Create(info, isToJs, nullable.KnownType, new[] { nullable.ResultTypeInfo.KnownType }, jsMarshalingInfo.JSType, null, fail); // array - case { ManagedType: JSArrayTypeInfo array }: + case { TypeInfo: JSArrayTypeInfo array }: return Create(info, isToJs, array.KnownType, new[] { array.ElementTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // array segment - case { ManagedType: JSArraySegmentTypeInfo segment }: + case { TypeInfo: JSArraySegmentTypeInfo segment }: return Create(info, isToJs, segment.KnownType, new[] { segment.ElementTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // span - case { ManagedType: JSSpanTypeInfo span }: + case { TypeInfo: JSSpanTypeInfo span }: return Create(info, isToJs, span.KnownType, new[] { span.ElementTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // task - case { ManagedType: JSTaskTypeInfo task } when task.ResultTypeInfo is JSSimpleTypeInfo taskRes && taskRes.FullTypeName == "void": + case { TypeInfo: JSTaskTypeInfo(JSSimpleTypeInfo(KnownManagedType.Void)) task }: return Create(info, isToJs, task.KnownType, Array.Empty(), jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); - case { ManagedType: JSTaskTypeInfo task }: + case { TypeInfo: JSTaskTypeInfo task }: return Create(info, isToJs, task.KnownType, new[] { task.ResultTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // action + function - case { ManagedType: JSFunctionTypeInfo function }: + case { TypeInfo: JSFunctionTypeInfo function }: return Create(info, isToJs, function.KnownType, function.ArgsTypeInfo.Select(a => a.KnownType).ToArray(), jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); default: diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs index 7bed0e9df4e..9c10b08ef87 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs @@ -47,9 +47,9 @@ internal sealed class JSImportCodeGenerator : JSCodeGenerator } // validate task + span mix - if (_marshallers.ManagedReturnMarshaller.TypeInfo.ManagedType is JSTaskTypeInfo) + if (_marshallers.ManagedReturnMarshaller.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSTaskTypeInfo)) { - BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.ManagedType is JSSpanTypeInfo); + BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSSpanTypeInfo)); if (spanArg != default) { marshallingNotSupportedCallback(spanArg.TypeInfo, new MarshallingNotSupportedException(spanArg.TypeInfo, _context) diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs index 055e81315e6..84d3cc2b692 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs @@ -145,19 +145,22 @@ static bool IsSkipLocalsInitAttribute(AttributeData a) private static (ImmutableArray, IMarshallingGeneratorFactory) GenerateTypeInformation(IMethodSymbol method, GeneratorDiagnostics diagnostics, StubEnvironment env) { - var jsMarshallingAttributeParser = new JSMarshallingAttributeInfoParser(env.Compilation, diagnostics, method); + ImmutableArray useSiteAttributeParsers = ImmutableArray.Create(new JSMarshalAsAttributeParser(env.Compilation)); + var jsMarshallingAttributeParser = new MarshallingInfoParser( + diagnostics, + new MethodSignatureElementInfoProvider(env.Compilation, diagnostics, method, useSiteAttributeParsers), + useSiteAttributeParsers, + ImmutableArray.Create(new JSMarshalAsAttributeParser(env.Compilation)), + ImmutableArray.Create(new FallbackJSMarshallingInfoProvider())); // Determine parameter and return types ImmutableArray.Builder typeInfos = ImmutableArray.CreateBuilder(); for (int i = 0; i < method.Parameters.Length; i++) { IParameterSymbol param = method.Parameters[i]; - MarshallingInfo marshallingInfo = NoMarshallingInfo.Instance; - MarshallingInfo jsMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes(), marshallingInfo); + MarshallingInfo jsMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); - var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation); - typeInfo = JSTypeInfo.CreateForType(typeInfo, param.Type, jsMarshallingInfo, env.Compilation); - typeInfo = typeInfo with + var typeInfo = TypePositionInfo.CreateForParameter(param, jsMarshallingInfo, env.Compilation) with { ManagedIndex = i, NativeIndex = typeInfos.Count, @@ -165,12 +168,9 @@ private static (ImmutableArray, IMarshallingGeneratorFactory) typeInfos.Add(typeInfo); } - MarshallingInfo retMarshallingInfo = NoMarshallingInfo.Instance; - MarshallingInfo retJSMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes(), retMarshallingInfo); + MarshallingInfo retJSMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes()); - var retTypeInfo = new TypePositionInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), retMarshallingInfo); - retTypeInfo = JSTypeInfo.CreateForType(retTypeInfo, method.ReturnType, retJSMarshallingInfo, env.Compilation); - retTypeInfo = retTypeInfo with + var retTypeInfo = new TypePositionInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), retJSMarshallingInfo) { ManagedIndex = TypePositionInfo.ReturnIndex, NativeIndex = TypePositionInfo.ReturnIndex, diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs index 0eedd0aa873..0f9084a01e8 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs @@ -4,185 +4,192 @@ using System; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Microsoft.Interop.JavaScript { - internal abstract record JSTypeInfo(string FullTypeName, string DiagnosticFormattedName, KnownManagedType KnownType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName) + internal abstract record JSTypeInfo(KnownManagedType KnownType) { - public static ManagedTypeInfo CreateJSTypeInfoForTypeSymbol(ITypeSymbol type) + public static JSTypeInfo CreateJSTypeInfoForTypeSymbol(ITypeSymbol type) { string fullTypeName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - if (fullTypeName == "void") + switch (type) { - return SpecialTypeInfo.Void; - } - string diagnosticFormattedName = type.ToDisplayString(); - return CreateJSTypeInfoForTypeSymbol(fullTypeName, diagnosticFormattedName); - } - - public static ManagedTypeInfo CreateJSTypeInfoForTypeSymbol(string fullTypeName, string diagnosticFormattedName) - { - switch (fullTypeName.Trim()) - { - case "global::System.Void": - case "void": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Void); - case "global::System.Boolean": - case "bool": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Boolean); - case "global::System.Byte": - case "byte": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Byte); - case "global::System.Char": - case "char": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Char); - case "global::System.Int16": - case "short": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Int16); - case "global::System.Int32": - case "int": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Int32); - case "global::System.Int64": - case "long": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Int64); - case "global::System.Single": - case "float": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Single); - case "global::System.Double": - case "double": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Double); - case "global::System.IntPtr": - case "nint": - case "void*": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.IntPtr); - case "global::System.DateTime": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.DateTime); - case "global::System.DateTimeOffset": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.DateTimeOffset); - case "global::System.Exception": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Exception); - case "global::System.Object": - case "object": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Object); - case "global::System.String": - case "string": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.String); - case "global::System.Runtime.InteropServices.JavaScript.JSObject": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.JSObject); + case { SpecialType: SpecialType.System_Void }: + return new JSSimpleTypeInfo(KnownManagedType.Void) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.VoidKeyword)) + }; + case { SpecialType: SpecialType.System_Boolean }: + return new JSSimpleTypeInfo(KnownManagedType.Boolean) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.BoolKeyword)) + }; + case { SpecialType: SpecialType.System_Byte }: + return new JSSimpleTypeInfo(KnownManagedType.Byte) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)) + }; + case { SpecialType: SpecialType.System_Char }: + return new JSSimpleTypeInfo(KnownManagedType.Char) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)) + }; + case { SpecialType: SpecialType.System_Int16 }: + return new JSSimpleTypeInfo(KnownManagedType.Int16) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ShortKeyword)) + }; + case { SpecialType: SpecialType.System_Int32 }: + return new JSSimpleTypeInfo(KnownManagedType.Int32) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.IntKeyword)) + }; + case { SpecialType: SpecialType.System_Int64 }: + return new JSSimpleTypeInfo(KnownManagedType.Int64) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.LongKeyword)) + }; + case { SpecialType: SpecialType.System_Single }: + return new JSSimpleTypeInfo(KnownManagedType.Single) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.FloatKeyword)) + }; + case { SpecialType: SpecialType.System_Double }: + return new JSSimpleTypeInfo(KnownManagedType.Double) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.DoubleKeyword)) + }; + case { SpecialType: SpecialType.System_IntPtr }: + case IPointerTypeSymbol { PointedAtType.SpecialType: SpecialType.System_Void }: + return new JSSimpleTypeInfo(KnownManagedType.IntPtr) + { + Syntax = SyntaxFactory.IdentifierName("nint") + }; + case { SpecialType: SpecialType.System_DateTime }: + return new JSSimpleTypeInfo(KnownManagedType.DateTime) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; + case ITypeSymbol when fullTypeName == "global::System.DateTimeOffset": + return new JSSimpleTypeInfo(KnownManagedType.DateTimeOffset) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; + case ITypeSymbol when fullTypeName == "global::System.Exception": + return new JSSimpleTypeInfo(KnownManagedType.Exception) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; + case { SpecialType: SpecialType.System_Object }: + return new JSSimpleTypeInfo(KnownManagedType.Object) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ObjectKeyword)) + }; + case { SpecialType: SpecialType.System_String }: + return new JSSimpleTypeInfo(KnownManagedType.String) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.StringKeyword)) + }; + case ITypeSymbol when fullTypeName == "global::System.Runtime.InteropServices.JavaScript.JSObject": + return new JSSimpleTypeInfo(KnownManagedType.JSObject) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; //nullable - case string ftn when ftn.EndsWith("?"): - var ut = fullTypeName.Remove(fullTypeName.Length - 1); - if (CreateJSTypeInfoForTypeSymbol(ut, diagnosticFormattedName) is JSSimpleTypeInfo uti) + case INamedTypeSymbol { ConstructedFrom.SpecialType: SpecialType.System_Nullable_T } nullable: + if (CreateJSTypeInfoForTypeSymbol(nullable.TypeArguments[0]) is JSSimpleTypeInfo uti) { - return new JSNullableTypeInfo(fullTypeName, diagnosticFormattedName, uti); + return new JSNullableTypeInfo(uti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // array - case string ftn when ftn.EndsWith("[]"): - var et = fullTypeName.Remove(fullTypeName.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(et, diagnosticFormattedName) is JSSimpleTypeInfo eti) + case IArrayTypeSymbol { IsSZArray: true, ElementType: ITypeSymbol elementType }: + if (CreateJSTypeInfoForTypeSymbol(elementType) is JSSimpleTypeInfo eti) { - return new JSArrayTypeInfo(fullTypeName, diagnosticFormattedName, eti); + return new JSArrayTypeInfo(eti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // task - case Constants.TaskGlobal: - return new JSTaskTypeInfo(fullTypeName, diagnosticFormattedName, (JSSimpleTypeInfo)CreateJSTypeInfoForTypeSymbol("void", diagnosticFormattedName)); - case string ft when ft.StartsWith(Constants.TaskGlobal): - var rt = fullTypeName.Substring(Constants.TaskGlobal.Length + 1, fullTypeName.Length - Constants.TaskGlobal.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(rt, diagnosticFormattedName) is JSSimpleTypeInfo rti) + case ITypeSymbol when fullTypeName == Constants.TaskGlobal: + return new JSTaskTypeInfo(new JSSimpleTypeInfo(KnownManagedType.Void, SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.VoidKeyword)))); + case INamedTypeSymbol { TypeArguments.Length: 1 } taskType when fullTypeName.StartsWith(Constants.TaskGlobal, StringComparison.Ordinal): + if (CreateJSTypeInfoForTypeSymbol(taskType.TypeArguments[0]) is JSSimpleTypeInfo rti) { - return new JSTaskTypeInfo(fullTypeName, diagnosticFormattedName, rti); + return new JSTaskTypeInfo(rti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // span - case string ft when ft.StartsWith(Constants.SpanGlobal): - var st = fullTypeName.Substring(Constants.SpanGlobal.Length + 1, fullTypeName.Length - Constants.SpanGlobal.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(st, diagnosticFormattedName) is JSSimpleTypeInfo sti) + case INamedTypeSymbol { TypeArguments.Length: 1 } spanType when fullTypeName.StartsWith(Constants.SpanGlobal, StringComparison.Ordinal): + if (CreateJSTypeInfoForTypeSymbol(spanType.TypeArguments[0]) is JSSimpleTypeInfo sti) { - return new JSSpanTypeInfo(fullTypeName, diagnosticFormattedName, sti); + return new JSSpanTypeInfo(sti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // array segment - case string ft when ft.StartsWith(Constants.ArraySegmentGlobal): - var gt = fullTypeName.Substring(Constants.ArraySegmentGlobal.Length + 1, fullTypeName.Length - Constants.ArraySegmentGlobal.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(gt, diagnosticFormattedName) is JSSimpleTypeInfo gti) + case INamedTypeSymbol { TypeArguments.Length: 1 } arraySegmentType when fullTypeName.StartsWith(Constants.ArraySegmentGlobal, StringComparison.Ordinal): + if (CreateJSTypeInfoForTypeSymbol(arraySegmentType.TypeArguments[0]) is JSSimpleTypeInfo gti) { - return new JSArraySegmentTypeInfo(fullTypeName, diagnosticFormattedName, gti); + return new JSArraySegmentTypeInfo(gti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // action - case Constants.ActionGlobal: - return new JSFunctionTypeInfo(fullTypeName, diagnosticFormattedName, true, Array.Empty()); - case string ft when ft.StartsWith(Constants.ActionGlobal): - var argNames = fullTypeName.Substring(Constants.ActionGlobal.Length + 1, fullTypeName.Length - Constants.ActionGlobal.Length - 2); - if (!argNames.Contains("<")) - { - var ga = argNames.Split(',') - .Select(argName => CreateJSTypeInfoForTypeSymbol(argName, diagnosticFormattedName) as JSSimpleTypeInfo) - .ToArray(); - if (ga.Any(x => x == null)) - { - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); - } - return new JSFunctionTypeInfo(fullTypeName, diagnosticFormattedName, true, ga); + case ITypeSymbol when fullTypeName == Constants.ActionGlobal: + return new JSFunctionTypeInfo(true, Array.Empty()); + case INamedTypeSymbol actionType when fullTypeName.StartsWith(Constants.ActionGlobal, StringComparison.Ordinal): + var argumentTypes = actionType.TypeArguments + .Select(arg => CreateJSTypeInfoForTypeSymbol(arg) as JSSimpleTypeInfo) + .ToArray(); + if (argumentTypes.Any(x => x is null)) + { + return new JSInvalidTypeInfo(); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSFunctionTypeInfo(true, argumentTypes); // function - case string ft when ft.StartsWith(Constants.FuncGlobal): - var fargNames = fullTypeName.Substring(Constants.FuncGlobal.Length + 1, fullTypeName.Length - Constants.FuncGlobal.Length - 2); - if (!fargNames.Contains("<")) - { - var ga = fargNames.Split(',') - .Select(argName => CreateJSTypeInfoForTypeSymbol(argName, diagnosticFormattedName) as JSSimpleTypeInfo) - .ToArray(); - if (ga.Any(x => x == null)) - { - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); - } - return new JSFunctionTypeInfo(fullTypeName, diagnosticFormattedName, false, ga); + case INamedTypeSymbol funcType when fullTypeName.StartsWith(Constants.FuncGlobal, StringComparison.Ordinal): + var signatureTypes = funcType.TypeArguments + .Select(argName => CreateJSTypeInfoForTypeSymbol(argName) as JSSimpleTypeInfo) + .ToArray(); + if (signatureTypes.Any(x => x is null)) + { + return new JSInvalidTypeInfo(); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSFunctionTypeInfo(false, signatureTypes); default: - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); } } + } - public static TypePositionInfo CreateForType(TypePositionInfo inner, ITypeSymbol type, MarshallingInfo jsMarshallingInfo, Compilation compilation) - { - ManagedTypeInfo jsTypeInfo = CreateJSTypeInfoForTypeSymbol(type); - var typeInfo = new TypePositionInfo(jsTypeInfo, jsMarshallingInfo) - { - InstanceIdentifier = inner.InstanceIdentifier, - RefKind = inner.RefKind, - RefKindSyntax = inner.RefKindSyntax, - ByValueContentsMarshalKind = inner.ByValueContentsMarshalKind - }; + internal sealed record JSInvalidTypeInfo() : JSSimpleTypeInfo(KnownManagedType.None); - return typeInfo; + internal record JSSimpleTypeInfo(KnownManagedType KnownType) : JSTypeInfo(KnownType) + { + public JSSimpleTypeInfo(KnownManagedType knownType, TypeSyntax syntax) + : this(knownType) + { + Syntax = syntax; } + public TypeSyntax Syntax { get; init; } } - internal sealed record JSInvalidTypeInfo(string FullTypeName, string DiagnosticFormattedName) : JSSimpleTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.None); - - internal record JSSimpleTypeInfo(string FullTypeName, string DiagnosticFormattedName, KnownManagedType KnownType) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownType); - - internal sealed record JSArrayTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Array); + internal sealed record JSArrayTypeInfo(JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(KnownManagedType.Array); - internal sealed record JSSpanTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Span); + internal sealed record JSSpanTypeInfo(JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(KnownManagedType.Span); - internal sealed record JSArraySegmentTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.ArraySegment); + internal sealed record JSArraySegmentTypeInfo(JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(KnownManagedType.ArraySegment); - internal sealed record JSTaskTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Task); + internal sealed record JSTaskTypeInfo(JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(KnownManagedType.Task); - internal sealed record JSNullableTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Nullable); + internal sealed record JSNullableTypeInfo(JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(KnownManagedType.Nullable); - internal sealed record JSFunctionTypeInfo(string FullTypeName, string DiagnosticFormattedName, bool IsAction, JSSimpleTypeInfo[] ArgsTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, (IsAction ? KnownManagedType.Action : KnownManagedType.Function)); + internal sealed record JSFunctionTypeInfo(bool IsAction, JSSimpleTypeInfo[] ArgsTypeInfo) : JSTypeInfo(IsAction ? KnownManagedType.Action : KnownManagedType.Function); } diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs new file mode 100644 index 00000000000..179cc065f03 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Runtime.InteropServices.JavaScript; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop.JavaScript +{ + internal sealed class JSMarshalAsAttributeParser : IMarshallingInfoAttributeParser, IUseSiteAttributeParser + { + private readonly INamedTypeSymbol _jsMarshalAsAttribute; + + public JSMarshalAsAttributeParser(Compilation compilation) + { + _jsMarshalAsAttribute = compilation.GetTypeByMetadataName(Constants.JSMarshalAsAttribute)!.ConstructUnboundGenericType(); + } + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.IsGenericType && SymbolEqualityComparer.Default.Equals(_jsMarshalAsAttribute, attributeType.ConstructUnboundGenericType()); + public MarshallingInfo ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + JSTypeFlags jsType = JSTypeFlags.None; + List jsTypeArguments = new List(); + INamedTypeSymbol? jsTypeArgs = attributeData.AttributeClass.TypeArguments[0] as INamedTypeSymbol; + if (jsTypeArgs.IsGenericType) + { + string gt = jsTypeArgs.ConstructUnboundGenericType().ToDisplayString(); + string name = gt.Substring(gt.IndexOf("JSType") + "JSType.".Length); + name = name.Substring(0, name.IndexOf("<")); + + Enum.TryParse(name, out jsType); + + foreach (var ta in jsTypeArgs.TypeArguments.Cast().Select(x => x.ToDisplayString())) + { + string argName = ta.Substring(ta.IndexOf("JSType") + "JSType.".Length); + JSTypeFlags jsTypeArg = JSTypeFlags.None; + Enum.TryParse(argName, out jsTypeArg); + jsTypeArguments.Add(jsTypeArg); + } + } + else + { + string st = jsTypeArgs.ToDisplayString(); + string name = st.Substring(st.IndexOf("JSType") + "JSType.".Length); + Enum.TryParse(name, out jsType); + } + + if (jsType == JSTypeFlags.None) + { + return new JSMissingMarshallingInfo(JSTypeInfo.CreateJSTypeInfoForTypeSymbol(type)); + } + + return new JSMarshallingInfo(NoMarshallingInfo.Instance, JSTypeInfo.CreateJSTypeInfoForTypeSymbol(type)) + { + JSType = jsType, + JSTypeArguments = jsTypeArguments.ToArray(), + }; + } + + UseSiteAttributeData IUseSiteAttributeParser.ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + return new UseSiteAttributeData(0, NoCountInfo.Instance, attributeData); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs deleted file mode 100644 index b6fdaa576b5..00000000000 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs +++ /dev/null @@ -1,86 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices.JavaScript; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Interop.JavaScript -{ - public sealed class JSMarshallingAttributeInfoParser - { - private readonly ITypeSymbol _jsMarshalAsAttribute; - private readonly ITypeSymbol _marshalUsingAttribute; - - public JSMarshallingAttributeInfoParser( - Compilation compilation, - IGeneratorDiagnostics diagnostics, - ISymbol contextSymbol) - { - _jsMarshalAsAttribute = compilation.GetTypeByMetadataName(Constants.JSMarshalAsAttribute)!.ConstructUnboundGenericType(); - _marshalUsingAttribute = compilation.GetTypeByMetadataName(Constants.MarshalUsingAttribute)!; - } - - public MarshallingInfo ParseMarshallingInfo( - ITypeSymbol managedType, - IEnumerable useSiteAttributes, - MarshallingInfo inner) - { - JSTypeFlags jsType = JSTypeFlags.None; - List jsTypeArguments = new List(); - - foreach (AttributeData useSiteAttribute in useSiteAttributes) - { - INamedTypeSymbol attributeClass = useSiteAttribute.AttributeClass!; - if (attributeClass.IsGenericType && SymbolEqualityComparer.Default.Equals(_jsMarshalAsAttribute, attributeClass.ConstructUnboundGenericType())) - { - INamedTypeSymbol? jsTypeArgs = attributeClass.TypeArguments[0] as INamedTypeSymbol; - if (jsTypeArgs.IsGenericType) - { - string gt = jsTypeArgs.ConstructUnboundGenericType().ToDisplayString(); - string name = gt.Substring(gt.IndexOf("JSType") + "JSType.".Length); - name = name.Substring(0, name.IndexOf("<")); - - Enum.TryParse(name, out jsType); - - foreach (var ta in jsTypeArgs.TypeArguments.Cast().Select(x => x.ToDisplayString())) - { - string argName = ta.Substring(ta.IndexOf("JSType") + "JSType.".Length); - JSTypeFlags jsTypeArg = JSTypeFlags.None; - Enum.TryParse(argName, out jsTypeArg); - jsTypeArguments.Add(jsTypeArg); - } - } - else - { - string st = jsTypeArgs.ToDisplayString(); - string name = st.Substring(st.IndexOf("JSType") + "JSType.".Length); - Enum.TryParse(name, out jsType); - } - - } - if (SymbolEqualityComparer.Default.Equals(_marshalUsingAttribute, attributeClass)) - { - return new JSMarshallingInfo(inner) - { - JSType = JSTypeFlags.Array, - JSTypeArguments = Array.Empty(), - }; - } - } - - if (jsType == JSTypeFlags.None) - { - return new JSMissingMarshallingInfo(); - } - - return new JSMarshallingInfo(inner) - { - JSType = jsType, - JSTypeArguments = jsTypeArguments.ToArray(), - }; - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs index d18f67b4792..edeb7961aad 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs @@ -6,26 +6,24 @@ namespace Microsoft.Interop.JavaScript { - internal record JSMarshallingInfo : MarshallingInfo + internal record JSMarshallingInfo(MarshallingInfo Inner, JSTypeInfo TypeInfo) : MarshallingInfo { - public MarshallingInfo Inner; - public JSTypeFlags JSType; - public JSTypeFlags[] JSTypeArguments; - public JSMarshallingInfo(MarshallingInfo inner) - { - Inner = inner; - } protected JSMarshallingInfo() + :this(NoMarshallingInfo.Instance, new JSInvalidTypeInfo()) { Inner = null; } + + public JSTypeFlags JSType { get; init; } + public JSTypeFlags[] JSTypeArguments { get; init; } } internal sealed record JSMissingMarshallingInfo : JSMarshallingInfo { - public JSMissingMarshallingInfo() + public JSMissingMarshallingInfo(JSTypeInfo typeInfo) { JSType = JSTypeFlags.Missing; + TypeInfo = typeInfo; } } } diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs index fed51dde6d6..ecbb1fcaa76 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs @@ -48,9 +48,9 @@ public override IEnumerable Generate(TypePositionInfo info, Stu ? Argument(IdentifierName(context.GetIdentifiers(info).native)) : _inner.AsArgument(info, context); - var jsty = (JSFunctionTypeInfo)info.ManagedType; + var jsty = (JSFunctionTypeInfo)((JSMarshallingInfo)info.MarshallingAttributeInfo).TypeInfo; var sourceTypes = jsty.ArgsTypeInfo - .Select(a => ParseTypeName(a.FullTypeName)) + .Select(a => a.Syntax) .ToArray(); if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && context.Direction == CustomTypeMarshallingDirection.In && info.IsManagedReturnPosition) diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs index 86162f0e62c..df9cee6af62 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs @@ -20,8 +20,8 @@ public TaskJSGenerator(MarshalerType resultMarshalerType) public override IEnumerable GenerateBind(TypePositionInfo info, StubCodeContext context) { - var jsty = (JSTaskTypeInfo)info.ManagedType; - if (jsty.ResultTypeInfo.FullTypeName == "void") + var jsty = (JSTaskTypeInfo)((JSMarshallingInfo)info.MarshallingAttributeInfo).TypeInfo; + if (jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void)) { yield return InvocationExpression(MarshalerTypeName(MarshalerType.Task), ArgumentList()); } @@ -34,7 +34,7 @@ public override IEnumerable GenerateBind(TypePositionInfo info public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { - var jsty = (JSTaskTypeInfo)info.ManagedType; + var jsty = (JSTaskTypeInfo)((JSMarshallingInfo)info.MarshallingAttributeInfo).TypeInfo; string argName = context.GetAdditionalIdentifier(info, "js_arg"); var target = info.IsManagedReturnPosition @@ -47,14 +47,14 @@ public override IEnumerable Generate(TypePositionInfo info, Stu if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && context.Direction == CustomTypeMarshallingDirection.In && info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToManagedMethodVoid(target, source) : ToManagedMethod(target, source, jsty.ResultTypeInfo.Syntax); } if (context.CurrentStage == StubCodeContext.Stage.Marshal && context.Direction == CustomTypeMarshallingDirection.Out && info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToJSMethodVoid(target, source) : ToJSMethod(target, source, jsty.ResultTypeInfo.Syntax); } @@ -66,14 +66,14 @@ public override IEnumerable Generate(TypePositionInfo info, Stu if (context.CurrentStage == StubCodeContext.Stage.Invoke && context.Direction == CustomTypeMarshallingDirection.In && !info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToJSMethodVoid(target, source) : ToJSMethod(target, source, jsty.ResultTypeInfo.Syntax); } if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && context.Direction == CustomTypeMarshallingDirection.Out && !info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToManagedMethodVoid(target, source) : ToManagedMethod(target, source, jsty.ResultTypeInfo.Syntax); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs index 204709cf60c..be23c9e3e55 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; @@ -86,7 +87,8 @@ private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbo // later user work. AnyDiagnosticsSink diagnostics = new(); StubEnvironment env = context.Compilation.CreateStubEnvironment(); - SignatureContext targetSignatureContext = SignatureContext.Create(method, CreateInteropAttributeDataFromDllImport(dllImportData), env, diagnostics, typeof(ConvertToLibraryImportAnalyzer).Assembly); + AttributeData dllImportAttribute = method.GetAttributes().First(attr => attr.AttributeClass.ToDisplayString() == TypeNames.DllImportAttribute); + SignatureContext targetSignatureContext = SignatureContext.Create(method, CreateInteropAttributeDataFromDllImport(dllImportData), env, diagnostics, dllImportAttribute, typeof(ConvertToLibraryImportAnalyzer).Assembly); var generatorFactoryKey = LibraryImportGeneratorHelpers.CreateGeneratorFactory(env, new LibraryImportGeneratorOptions(context.Options.AnalyzerConfigOptionsProvider.GlobalOptions)); diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs index 1da4dbd49af..f74db2f8d84 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs @@ -312,7 +312,7 @@ private static SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenLis } // Create the stub. - var signatureContext = SignatureContext.Create(symbol, libraryImportData, environment, generatorDiagnostics, typeof(LibraryImportGenerator).Assembly); + var signatureContext = SignatureContext.Create(symbol, libraryImportData, environment, generatorDiagnostics, generatedDllImportAttr, typeof(LibraryImportGenerator).Assembly); var containingTypeContext = new ContainingSyntaxContext(originalSyntax); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs new file mode 100644 index 00000000000..2c58acff79b --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for single-dimensional zero-based array types using the System.Runtime.InteropServices.Marshalling.ArrayMarshaller and System.Runtime.InteropServices.Marshalling.PointerArrayMarshaller + /// built-in types. + /// + public sealed class ArrayMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + + public ArrayMarshallingInfoProvider(Compilation compilation) + { + _compilation = compilation; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type is IArrayTypeSymbol { IsSZArray: true }; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + + ITypeSymbol elementType = ((IArrayTypeSymbol)type).ElementType; + return CreateArrayMarshallingInfo(_compilation, type, elementType, countInfo, marshallingInfoCallback(elementType, useSiteAttributes, indirectionDepth + 1)); + } + + public static MarshallingInfo CreateArrayMarshallingInfo( + Compilation compilation, + ITypeSymbol managedType, + ITypeSymbol elementType, + CountInfo countInfo, + MarshallingInfo elementMarshallingInfo) + { + ITypeSymbol typeArgumentToInsert = elementType; + INamedTypeSymbol? arrayMarshaller; + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_PointerArrayMarshaller_Metadata); + typeArgumentToInsert = pointedAt; + } + else + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_ArrayMarshaller_Metadata); + } + + if (arrayMarshaller is null) + { + // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. + return new MissingSupportCollectionMarshallingInfo(countInfo, elementMarshallingInfo); + } + + if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(arrayMarshaller) + && ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(arrayMarshaller)) + { + arrayMarshaller = arrayMarshaller.Construct( + typeArgumentToInsert, + arrayMarshaller.TypeArguments.Last()); + + Func getMarshallingInfoForElement = (ITypeSymbol elementType) => elementMarshallingInfo; + if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(arrayMarshaller, managedType, compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? marshallers)) + { + return new NativeLinearCollectionMarshallingInfo( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), + marshallers.Value, + countInfo, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller.TypeParameters.Last())); + } + } + + Debug.WriteLine("Default marshallers for arrays should be a valid shape."); + return NoMarshallingInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs new file mode 100644 index 00000000000..cc3a1996957 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for unmanaged types that may be blittable. + /// + public sealed class BlittableTypeMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + + public BlittableTypeMarshallingInfoProvider(Compilation compilation) + { + _compilation = compilation; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type is INamedTypeSymbol { IsUnmanagedType: true } unmanagedType + && unmanagedType.IsConsideredBlittable(); + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + if (type.TypeKind is TypeKind.Enum or TypeKind.Pointer or TypeKind.FunctionPointer + || type.SpecialType.IsAlwaysBlittable()) + { + // Treat primitive types and enums as having no marshalling info. + // They are supported in configurations where runtime marshalling is enabled. + return NoMarshallingInfo.Instance; + } + else if (_compilation.GetTypeByMetadataName(TypeNames.System_Runtime_CompilerServices_DisableRuntimeMarshallingAttribute) is null) + { + // If runtime marshalling cannot be disabled, then treat this as a "missing support" scenario so we can gracefully fall back to using the forwarder downlevel. + return new MissingSupportMarshallingInfo(); + } + else + { + return new UnmanagedBlittableMarshallingInfo(type.IsStrictlyBlittable()); + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs new file mode 100644 index 00000000000..6a7b78acf1d --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for bool elements without any marshalling information. + /// + public sealed class BooleanMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type.SpecialType == SpecialType.System_Boolean; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + // We intentionally don't support marshalling bool with no marshalling info + // as treating bool as a non-normalized 1-byte value is generally not a good default. + // Additionally, that default is different than the runtime marshalling, so by explicitly + // blocking bool marshalling without additional info, we make it a little easier + // to transition by explicitly notifying people of changing behavior. + return NoMarshallingInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs new file mode 100644 index 00000000000..b8b6f4c4c16 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for char elements without any marshalling information on the element itself. + /// + public sealed class CharMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly DefaultMarshallingInfo _defaultMarshallingInfo; + + public CharMarshallingInfoProvider(DefaultMarshallingInfo defaultMarshallingInfo) + { + _defaultMarshallingInfo = defaultMarshallingInfo; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type.SpecialType == SpecialType.System_Char; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + // No marshalling info was computed, but a character encoding was provided. + // If the type is a character then pass on these details. + return _defaultMarshallingInfo.CharEncoding == CharEncoding.Undefined ? new UnmanagedBlittableMarshallingInfo(IsStrictlyBlittable: false) : new MarshallingInfoStringSupport(_defaultMarshallingInfo.CharEncoding); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs new file mode 100644 index 00000000000..d1df5d4e19e --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + internal static class CustomMarshallingInfoHelper + { + public static MarshallingInfo CreateNativeMarshallingInfo( + ITypeSymbol type, + INamedTypeSymbol entryPointType, + AttributeData attrData, + UseSiteAttributeProvider useSiteAttributeProvider, + GetMarshallingInfoCallback getMarshallingInfoCallback, + int indirectionDepth, + CountInfo parsedCountInfo, + IGeneratorDiagnostics diagnostics, + Compilation compilation) + { + if (!ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(entryPointType)) + { + return NoMarshallingInfo.Instance; + } + + if (!(entryPointType.IsStatic && entryPointType.TypeKind == TypeKind.Class) + && entryPointType.TypeKind != TypeKind.Struct) + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerTypeMustBeStaticClassOrStruct), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + ManagedTypeInfo entryPointTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType); + + bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); + if (isLinearCollectionMarshalling) + { + // Update the entry point type with the type arguments based on the managed type + if (type is IArrayTypeSymbol arrayManagedType) + { + // Generally, we require linear collection marshallers to have "arity of managed type + 1" arity. + // However, arrays aren't "generic" over their element type as they're generics, but we want to treat the element type + // as a generic type parameter. As a result, we require an arity of 2 for array marshallers, 1 for the array element type, + // and 1 for the native element type (the required additional type parameter for linear collection marshallers). + if (entryPointType.Arity != 2) + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + entryPointType = entryPointType.ConstructedFrom.Construct( + arrayManagedType.ElementType, + entryPointType.TypeArguments.Last()); + } + else if (type is INamedTypeSymbol namedManagedCollectionType && entryPointType.IsUnboundGenericType) + { + if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( + namedManagedCollectionType, + entryPointType, + isLinearCollectionMarshalling, + (type, entryPointType) => diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), + out ITypeSymbol resolvedEntryPointType)) + { + return NoMarshallingInfo.Instance; + } + + entryPointType = (INamedTypeSymbol)resolvedEntryPointType; + } + else + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + Func getMarshallingInfoForElement = (ITypeSymbol elementType) => getMarshallingInfoCallback(elementType, useSiteAttributeProvider, indirectionDepth + 1); + if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers)) + { + return new NativeLinearCollectionMarshallingInfo( + entryPointTypeInfo, + collectionMarshallers.Value, + parsedCountInfo, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType.TypeParameters.Last())); + } + return NoMarshallingInfo.Instance; + } + + if (type is INamedTypeSymbol namedManagedType && entryPointType.IsUnboundGenericType) + { + if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( + namedManagedType, + entryPointType, + isLinearCollectionMarshalling, + (type, entryPointType) => diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), + out ITypeSymbol resolvedEntryPointType)) + { + return NoMarshallingInfo.Instance; + } + + entryPointType = (INamedTypeSymbol)resolvedEntryPointType; + } + + if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(entryPointType, type, compilation, out CustomTypeMarshallers? marshallers)) + { + return new NativeMarshallingAttributeInfo(entryPointTypeInfo, marshallers.Value); + } + return NoMarshallingInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs new file mode 100644 index 00000000000..4294f11610f --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Simple User-application of System.Runtime.InteropServices.MarshalAsAttribute + /// + public sealed record MarshalAsInfo( + UnmanagedType UnmanagedType, + CharEncoding CharEncoding) : MarshallingInfoStringSupport(CharEncoding) + { + // UnmanagedType.LPUTF8Str is not in netstandard2.0, so we define a constant for the value here. + // See https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.unmanagedtype + internal const UnmanagedType UnmanagedType_LPUTF8Str = (UnmanagedType)0x30; + } + + /// + /// This class suppports parsing a System.Runtime.InteropServices.MarshalAsAttribute. + /// + public sealed class MarshalAsAttributeParser : IMarshallingInfoAttributeParser, IUseSiteAttributeParser + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + private readonly DefaultMarshallingInfo _defaultInfo; + + public MarshalAsAttributeParser(Compilation compilation, IGeneratorDiagnostics diagnostics, DefaultMarshallingInfo defaultInfo) + { + _compilation = compilation; + _diagnostics = diagnostics; + _defaultInfo = defaultInfo; + } + + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.System_Runtime_InteropServices_MarshalAsAttribute; + + UseSiteAttributeData IUseSiteAttributeParser.ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + ImmutableDictionary namedArguments = ImmutableDictionary.CreateRange(attributeData.NamedArguments); + SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; + if (namedArguments.TryGetValue(nameof(MarshalAsAttribute.SizeConst), out TypedConstant sizeConstArg)) + { + arraySizeInfo = arraySizeInfo with { ConstSize = (int)sizeConstArg.Value! }; + } + if (namedArguments.TryGetValue(nameof(MarshalAsAttribute.SizeParamIndex), out TypedConstant sizeParamIndexArg)) + { + if (!elementInfoProvider.TryGetInfoForParamIndex(attributeData, (short)sizeParamIndexArg.Value!, marshallingInfoCallback, out TypePositionInfo paramIndexInfo)) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, nameof(MarshalAsAttribute.SizeParamIndex), sizeParamIndexArg.Value.ToString()); + } + arraySizeInfo = arraySizeInfo with { ParamAtIndex = paramIndexInfo }; + } + return new UseSiteAttributeData(0, arraySizeInfo, attributeData); + } + + MarshallingInfo? IMarshallingInfoAttributeParser.ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + object unmanagedTypeObj = attributeData.ConstructorArguments[0].Value!; + UnmanagedType unmanagedType = unmanagedTypeObj is short unmanagedTypeAsShort + ? (UnmanagedType)unmanagedTypeAsShort + : (UnmanagedType)unmanagedTypeObj; + if (!Enum.IsDefined(typeof(UnmanagedType), unmanagedType) + || unmanagedType == UnmanagedType.CustomMarshaler + || unmanagedType == UnmanagedType.SafeArray) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, nameof(UnmanagedType), unmanagedType.ToString()); + } + + bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; + UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize; + + // All other data on attribute is defined as NamedArguments. + foreach (KeyValuePair namedArg in attributeData.NamedArguments) + { + switch (namedArg.Key) + { + case nameof(MarshalAsAttribute.SafeArraySubType): + case nameof(MarshalAsAttribute.SafeArrayUserDefinedSubType): + case nameof(MarshalAsAttribute.IidParameterIndex): + case nameof(MarshalAsAttribute.MarshalTypeRef): + case nameof(MarshalAsAttribute.MarshalType): + case nameof(MarshalAsAttribute.MarshalCookie): + _diagnostics.ReportConfigurationNotSupported(attributeData, $"{attributeData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + break; + case nameof(MarshalAsAttribute.ArraySubType): + if (!isArrayType) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, $"{attributeData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + } + elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; + break; + } + } + + if (isArrayType) + { + if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, nameof(UnmanagedType), unmanagedType.ToString()); + return NoMarshallingInfo.Instance; + } + + MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; + if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize) + { + elementMarshallingInfo = elementType.SpecialType == SpecialType.System_String + ? CreateStringMarshallingInfo(elementType, elementUnmanagedType) + : new MarshalAsInfo(elementUnmanagedType, _defaultInfo.CharEncoding); + } + else + { + elementMarshallingInfo = marshallingInfoCallback(elementType, useSiteAttributes, indirectionDepth + 1); + } + + CountInfo countInfo = NoCountInfo.Instance; + + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteAttributeData)) + { + countInfo = useSiteAttributeData.CountInfo; + } + + return ArrayMarshallingInfoProvider.CreateArrayMarshallingInfo(_compilation, type, elementType, countInfo, elementMarshallingInfo); + } + + if (type.SpecialType == SpecialType.System_String) + { + return CreateStringMarshallingInfo(type, unmanagedType); + } + + return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); + } + + private MarshallingInfo CreateStringMarshallingInfo( + ITypeSymbol type, + UnmanagedType unmanagedType) + { + string? marshallerName = unmanagedType switch + { + UnmanagedType.BStr => TypeNames.BStrStringMarshaller, + UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller, + UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller, + MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller, + _ => null + }; + + if (marshallerName is null) + return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); + + return StringMarshallingInfoProvider.CreateStringMarshallingInfo(_compilation, type, marshallerName); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs new file mode 100644 index 00000000000..14d60d3c9ce --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// This class suppports parsing a System.Runtime.InteropServices.Marshalling.MarshalUsingAttribute. + /// + public sealed class MarshalUsingAttributeParser : IMarshallingInfoAttributeParser, IUseSiteAttributeParser + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + + public MarshalUsingAttributeParser(Compilation compilation, IGeneratorDiagnostics diagnostics) + { + _compilation = compilation; + _diagnostics = diagnostics; + } + + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.MarshalUsingAttribute; + + MarshallingInfo? IMarshallingInfoAttributeParser.ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + Debug.Assert(attributeData.AttributeClass!.ToDisplayString() == TypeNames.MarshalUsingAttribute); + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + + if (attributeData.ConstructorArguments.Length == 0) + { + // This attribute only has count information. + // It does not provide any marshalling info. + // Return null here to respresent the lack of any marshalling info, + // instead of the presence of invalid marshalling info. + return null; + } + + if (attributeData.ConstructorArguments[0].Value is not INamedTypeSymbol namedType) + { + return NoMarshallingInfo.Instance; + } + + return CustomMarshallingInfoHelper.CreateNativeMarshallingInfo( + type, + namedType, + attributeData, + useSiteAttributes, + marshallingInfoCallback, + indirectionDepth, + countInfo, + _diagnostics, + _compilation + ); + } + + UseSiteAttributeData IUseSiteAttributeParser.ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + ImmutableDictionary namedArgs = ImmutableDictionary.CreateRange(attributeData.NamedArguments); + CountInfo countInfo = ParseCountInfo(attributeData, namedArgs, elementInfoProvider, marshallingInfoCallback); + int elementIndirectionDepth = namedArgs.TryGetValue(ManualTypeMarshallingHelper.MarshalUsingProperties.ElementIndirectionDepth, out TypedConstant value) ? (int)value.Value! : 0; + return new UseSiteAttributeData(elementIndirectionDepth, countInfo, attributeData); + } + + private CountInfo ParseCountInfo(AttributeData attributeData, ImmutableDictionary namedArguments, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + int? constSize = null; + string? elementName = null; + foreach (KeyValuePair arg in attributeData.NamedArguments) + { + if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.ConstantElementCount) + { + constSize = (int)arg.Value.Value!; + } + else if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName) + { + if (arg.Value.Value is null) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, "null"); + return NoCountInfo.Instance; + } + elementName = (string)arg.Value.Value!; + } + } + + if (constSize is not null && elementName is not null) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attributeData, nameof(SR.ConstantAndElementCountInfoDisallowed)); + } + else if (constSize is not null) + { + return new ConstSizeCountInfo(constSize.Value); + } + else if (elementName is not null) + { + if (!elementInfoProvider.TryGetInfoForElementName(attributeData, elementName, marshallingInfoCallback, out TypePositionInfo elementInfo)) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, elementName); + return NoCountInfo.Instance; + } + return new CountElementCountInfo(elementInfo); + } + + return NoCountInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index c245f8db61b..ab849630886 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -281,13 +281,21 @@ public bool AnyIncomingEdge(int to) { if (nestedCollection.ElementCountInfo is CountElementCountInfo { ElementInfo: TypePositionInfo nestedCountElement }) { - yield return nestedCountElement; + // Do not include dependent elements with no managed or native index. + // These values are dummy values that are inserted earlier to avoid emitting extra diagnostics. + if (nestedCountElement.ManagedIndex != TypePositionInfo.UnsetIndex || nestedCountElement.NativeIndex != TypePositionInfo.UnsetIndex) + { + yield return nestedCountElement; + } } foreach (KeyValuePair mode in nestedCollection.Marshallers.Modes) { - foreach (TypePositionInfo nestedElements in GetDependentElementsOfMarshallingInfo(mode.Value.CollectionElementMarshallingInfo)) + foreach (TypePositionInfo nestedElement in GetDependentElementsOfMarshallingInfo(mode.Value.CollectionElementMarshallingInfo)) { - yield return nestedElements; + if (nestedElement.ManagedIndex != TypePositionInfo.UnsetIndex || nestedElement.NativeIndex != TypePositionInfo.UnsetIndex) + { + yield return nestedElement; + } } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index 19680588e7b..027267cf373 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -15,7 +15,7 @@ namespace Microsoft.Interop /// Type used to pass on default marshalling details. /// /// - /// This type used to pass default marshalling details to . + /// This type used to pass default marshalling details to the various marshalling info parsers. /// Since it contains a , it should not be used as a field on any types /// derived from . See remarks on . /// @@ -79,18 +79,6 @@ public enum CharEncoding CharEncoding CharEncoding ) : MarshallingInfo; - /// - /// Simple User-application of System.Runtime.InteropServices.MarshalAsAttribute - /// - public sealed record MarshalAsInfo( - UnmanagedType UnmanagedType, - CharEncoding CharEncoding) : MarshallingInfoStringSupport(CharEncoding) - { - // UnmanagedType.LPUTF8Str is not in netstandard2.0, so we define a constant for the value here. - // See https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.unmanagedtype - internal const UnmanagedType UnmanagedType_LPUTF8Str = (UnmanagedType)0x30; - } - /// /// The provided type was determined to be an "unmanaged" type that can be passed as-is to native code. /// @@ -145,11 +133,6 @@ bool IsStrictlyBlittable EntryPointType, Marshallers); - /// - /// The type of the element is a SafeHandle-derived type with no marshalling attributes. - /// - public sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor, bool IsAbstract) : MarshallingInfo; - /// /// Marshalling information is lacking because of support not because it is /// unknown or non-existent. Includes information about element types in case @@ -160,720 +143,4 @@ bool IsStrictlyBlittable /// the forwarder marshaller. /// public sealed record MissingSupportCollectionMarshallingInfo(CountInfo CountInfo, MarshallingInfo ElementMarshallingInfo) : MissingSupportMarshallingInfo; - - public sealed class MarshallingAttributeInfoParser - { - private readonly Compilation _compilation; - private readonly IGeneratorDiagnostics _diagnostics; - private readonly DefaultMarshallingInfo _defaultInfo; - private readonly ISymbol _contextSymbol; - private readonly ITypeSymbol _marshalAsAttribute; - private readonly ITypeSymbol _marshalUsingAttribute; - - public MarshallingAttributeInfoParser( - Compilation compilation, - IGeneratorDiagnostics diagnostics, - DefaultMarshallingInfo defaultInfo, - ISymbol contextSymbol) - { - _compilation = compilation; - _diagnostics = diagnostics; - _defaultInfo = defaultInfo; - _contextSymbol = contextSymbol; - _marshalAsAttribute = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)!; - _marshalUsingAttribute = compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute)!; - } - - public MarshallingInfo ParseMarshallingInfo( - ITypeSymbol managedType, - IEnumerable useSiteAttributes) - { - return ParseMarshallingInfo(managedType, useSiteAttributes, ImmutableHashSet.Empty); - } - - private MarshallingInfo ParseMarshallingInfo( - ITypeSymbol managedType, - IEnumerable useSiteAttributes, - ImmutableHashSet inspectedElements) - { - Dictionary marshallingAttributesByIndirectionDepth = new(); - int maxIndirectionLevelDataProvided = 0; - foreach (AttributeData attribute in useSiteAttributes) - { - if (TryGetAttributeIndirectionLevel(attribute, out int indirectionLevel)) - { - if (marshallingAttributesByIndirectionDepth.ContainsKey(indirectionLevel)) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attribute, nameof(SR.DuplicateMarshallingInfo), indirectionLevel.ToString()); - return NoMarshallingInfo.Instance; - } - marshallingAttributesByIndirectionDepth.Add(indirectionLevel, attribute); - maxIndirectionLevelDataProvided = Math.Max(maxIndirectionLevelDataProvided, indirectionLevel); - } - } - - int maxIndirectionDepthUsed = 0; - MarshallingInfo info = GetMarshallingInfo( - managedType, - marshallingAttributesByIndirectionDepth, - indirectionLevel: 0, - inspectedElements, - ref maxIndirectionDepthUsed); - if (maxIndirectionDepthUsed < maxIndirectionLevelDataProvided) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo( - marshallingAttributesByIndirectionDepth[maxIndirectionLevelDataProvided], - nameof(SR.ExtraneousMarshallingInfo), - maxIndirectionLevelDataProvided.ToString(), - maxIndirectionDepthUsed.ToString()); - } - return info; - } - - private MarshallingInfo GetMarshallingInfo( - ITypeSymbol type, - Dictionary useSiteAttributes, - int indirectionLevel, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed) - { - maxIndirectionDepthUsed = Math.Max(indirectionLevel, maxIndirectionDepthUsed); - CountInfo parsedCountInfo = NoCountInfo.Instance; - - if (useSiteAttributes.TryGetValue(indirectionLevel, out AttributeData useSiteAttribute)) - { - INamedTypeSymbol attributeClass = useSiteAttribute.AttributeClass!; - - if (indirectionLevel == 0 - && SymbolEqualityComparer.Default.Equals(_compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) - { - // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateInfoFromMarshalAs(type, useSiteAttribute, inspectedElements, ref maxIndirectionDepthUsed); - } - else if (SymbolEqualityComparer.Default.Equals(_compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass)) - { - if (parsedCountInfo != NoCountInfo.Instance) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(useSiteAttribute, nameof(SR.DuplicateCountInfo)); - return NoMarshallingInfo.Instance; - } - parsedCountInfo = CreateCountInfo(useSiteAttribute, inspectedElements); - if (useSiteAttribute.ConstructorArguments.Length != 0) - { - return CreateNativeMarshallingInfo( - type, - (INamedTypeSymbol)useSiteAttribute.ConstructorArguments[0].Value!, - useSiteAttribute, - isMarshalUsingAttribute: true, - indirectionLevel, - parsedCountInfo, - useSiteAttributes, - inspectedElements, - ref maxIndirectionDepthUsed); - } - } - } - - // If we aren't overriding the marshalling at usage time, - // then fall back to the information on the element type itself. - foreach (AttributeData typeAttribute in type.GetAttributes()) - { - INamedTypeSymbol attributeClass = typeAttribute.AttributeClass!; - - if (attributeClass.ToDisplayString() == TypeNames.NativeMarshallingAttribute) - { - return CreateNativeMarshallingInfo( - type, - (INamedTypeSymbol)typeAttribute.ConstructorArguments[0].Value!, - typeAttribute, - isMarshalUsingAttribute: false, - indirectionLevel, - parsedCountInfo, - useSiteAttributes, - inspectedElements, - ref maxIndirectionDepthUsed); - } - } - - // If the type doesn't have custom attributes that dictate marshalling, - // then consider the type itself. - if (TryCreateTypeBasedMarshallingInfo( - type, - parsedCountInfo, - indirectionLevel, - useSiteAttributes, - inspectedElements, - ref maxIndirectionDepthUsed, - out MarshallingInfo infoMaybe)) - { - return infoMaybe; - } - - return NoMarshallingInfo.Instance; - } - - private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashSet inspectedElements) - { - int? constSize = null; - string? elementName = null; - foreach (KeyValuePair arg in marshalUsingData.NamedArguments) - { - if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.ConstantElementCount) - { - constSize = (int)arg.Value.Value!; - } - else if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName) - { - if (arg.Value.Value is null) - { - _diagnostics.ReportConfigurationNotSupported(marshalUsingData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, "null"); - return NoCountInfo.Instance; - } - elementName = (string)arg.Value.Value!; - } - } - - if (constSize is not null && elementName is not null) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(marshalUsingData, nameof(SR.ConstantAndElementCountInfoDisallowed)); - } - else if (constSize is not null) - { - return new ConstSizeCountInfo(constSize.Value); - } - else if (elementName is not null) - { - if (inspectedElements.Contains(elementName)) - { - throw new CyclicalCountElementInfoException(inspectedElements, elementName); - } - - try - { - TypePositionInfo? elementInfo = CreateForElementName(elementName, inspectedElements.Add(elementName)); - if (elementInfo is null) - { - _diagnostics.ReportConfigurationNotSupported(marshalUsingData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, elementName); - return NoCountInfo.Instance; - } - return new CountElementCountInfo(elementInfo); - } - // Specifically catch the exception when we're trying to inspect the element that started the cycle. - // This ensures that we've unwound the whole cycle so when we return NoCountInfo.Instance, there will be no cycles in the count info. - catch (CyclicalCountElementInfoException ex) when (ex.StartOfCycle == elementName) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(marshalUsingData, nameof(SR.CyclicalCountInfo), elementName); - return NoCountInfo.Instance; - } - } - - return NoCountInfo.Instance; - } - - private TypePositionInfo? CreateForParamIndex(AttributeData attrData, int paramIndex, ImmutableHashSet inspectedElements) - { - if (!(_contextSymbol is IMethodSymbol method && 0 <= paramIndex && paramIndex < method.Parameters.Length)) - { - return null; - } - IParameterSymbol param = method.Parameters[paramIndex]; - - if (inspectedElements.Contains(param.Name)) - { - throw new CyclicalCountElementInfoException(inspectedElements, param.Name); - } - - try - { - return TypePositionInfo.CreateForParameter( - param, - ParseMarshallingInfo(param.Type, param.GetAttributes(), inspectedElements.Add(param.Name)), _compilation) with - { ManagedIndex = paramIndex }; - } - // Specifically catch the exception when we're trying to inspect the element that started the cycle. - // This ensures that we've unwound the whole cycle so when we return, there will be no cycles in the count info. - catch (CyclicalCountElementInfoException ex) when (ex.StartOfCycle == param.Name) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CyclicalCountInfo), param.Name); - return SizeAndParamIndexInfo.UnspecifiedParam; - } - } - - private TypePositionInfo? CreateForElementName(string elementName, ImmutableHashSet inspectedElements) - { - if (_contextSymbol is IMethodSymbol method) - { - if (elementName == CountElementCountInfo.ReturnValueElementName) - { - return new TypePositionInfo( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), - ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes(), inspectedElements)) with - { - ManagedIndex = TypePositionInfo.ReturnIndex - }; - } - - for (int i = 0; i < method.Parameters.Length; i++) - { - IParameterSymbol param = method.Parameters[i]; - if (param.Name == elementName) - { - return TypePositionInfo.CreateForParameter(param, ParseMarshallingInfo(param.Type, param.GetAttributes(), inspectedElements), _compilation) with { ManagedIndex = i }; - } - } - } - else if (_contextSymbol is INamedTypeSymbol) - { - // TODO: Handle when we create a struct marshalling generator - // Do we want to support CountElementName pointing to only fields, or properties as well? - // If only fields, how do we handle properties with generated backing fields? - } - - return null; - } - - private MarshallingInfo CreateInfoFromMarshalAs( - ITypeSymbol type, - AttributeData attrData, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed) - { - object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; - UnmanagedType unmanagedType = unmanagedTypeObj is short unmanagedTypeAsShort - ? (UnmanagedType)unmanagedTypeAsShort - : (UnmanagedType)unmanagedTypeObj; - if (!Enum.IsDefined(typeof(UnmanagedType), unmanagedType) - || unmanagedType == UnmanagedType.CustomMarshaler - || unmanagedType == UnmanagedType.SafeArray) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - } - - bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; - UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize; - SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; - - // All other data on attribute is defined as NamedArguments. - foreach (KeyValuePair namedArg in attrData.NamedArguments) - { - switch (namedArg.Key) - { - default: - Debug.Fail($"An unknown member was found on {nameof(MarshalAsAttribute)}"); - continue; - case nameof(MarshalAsAttribute.SafeArraySubType): - case nameof(MarshalAsAttribute.SafeArrayUserDefinedSubType): - case nameof(MarshalAsAttribute.IidParameterIndex): - case nameof(MarshalAsAttribute.MarshalTypeRef): - case nameof(MarshalAsAttribute.MarshalType): - case nameof(MarshalAsAttribute.MarshalCookie): - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - break; - case nameof(MarshalAsAttribute.ArraySubType): - if (!isArrayType) - { - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; - break; - case nameof(MarshalAsAttribute.SizeConst): - if (!isArrayType) - { - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - arraySizeInfo = arraySizeInfo with { ConstSize = (int)namedArg.Value.Value! }; - break; - case nameof(MarshalAsAttribute.SizeParamIndex): - if (!isArrayType) - { - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - TypePositionInfo? paramIndexInfo = CreateForParamIndex(attrData, (short)namedArg.Value.Value!, inspectedElements); - - if (paramIndexInfo is null) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(MarshalAsAttribute.SizeParamIndex), namedArg.Value.Value.ToString()); - } - arraySizeInfo = arraySizeInfo with { ParamAtIndex = paramIndexInfo }; - break; - } - } - - if (isArrayType) - { - if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - return NoMarshallingInfo.Instance; - } - - MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; - if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize) - { - elementMarshallingInfo = elementType.SpecialType == SpecialType.System_String - ? CreateStringMarshallingInfo(elementType, elementUnmanagedType) - : new MarshalAsInfo(elementUnmanagedType, _defaultInfo.CharEncoding); - } - else - { - maxIndirectionDepthUsed = 1; - elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsed); - } - - return CreateArrayMarshallingInfo(type, elementType, arraySizeInfo, elementMarshallingInfo); - } - - if (type.SpecialType == SpecialType.System_String) - { - return CreateStringMarshallingInfo(type, unmanagedType); - } - - return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); - } - - private MarshallingInfo CreateNativeMarshallingInfo( - ITypeSymbol type, - INamedTypeSymbol entryPointType, - AttributeData attrData, - bool isMarshalUsingAttribute, - int indirectionLevel, - CountInfo parsedCountInfo, - Dictionary useSiteAttributes, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed) - { - if (!ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(entryPointType)) - { - return NoMarshallingInfo.Instance; - } - - if (!(entryPointType.IsStatic && entryPointType.TypeKind == TypeKind.Class) - && entryPointType.TypeKind != TypeKind.Struct) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerTypeMustBeStaticClassOrStruct), entryPointType.ToDisplayString(), type.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - ManagedTypeInfo entryPointTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType); - - bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); - if (isLinearCollectionMarshalling) - { - // Update the entry point type with the type arguments based on the managed type - if (type is IArrayTypeSymbol arrayManagedType) - { - if (entryPointType.Arity != 2) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - entryPointType = entryPointType.ConstructedFrom.Construct( - arrayManagedType.ElementType, - entryPointType.TypeArguments.Last()); - } - else if (type is INamedTypeSymbol namedManagedCollectionType && entryPointType.IsUnboundGenericType) - { - if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( - namedManagedCollectionType, - entryPointType, - isLinearCollectionMarshalling, - (type, entryPointType) => _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), - out ITypeSymbol resolvedEntryPointType)) - { - return NoMarshallingInfo.Instance; - } - - entryPointType = (INamedTypeSymbol)resolvedEntryPointType; - } - else - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed; - Func getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsedLocal); - if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers)) - { - maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal; - return new NativeLinearCollectionMarshallingInfo( - entryPointTypeInfo, - collectionMarshallers.Value, - parsedCountInfo, - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType.TypeParameters.Last())); - } - return NoMarshallingInfo.Instance; - } - - if (type is INamedTypeSymbol namedManagedType && entryPointType.IsUnboundGenericType) - { - if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( - namedManagedType, - entryPointType, - isLinearCollectionMarshalling, - (type, entryPointType) => _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), - out ITypeSymbol resolvedEntryPointType)) - { - return NoMarshallingInfo.Instance; - } - - entryPointType = (INamedTypeSymbol)resolvedEntryPointType; - } - - if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(entryPointType, type, _compilation, out CustomTypeMarshallers? marshallers)) - { - return new NativeMarshallingAttributeInfo(entryPointTypeInfo, marshallers.Value); - } - return NoMarshallingInfo.Instance; - } - - private bool TryCreateTypeBasedMarshallingInfo( - ITypeSymbol type, - CountInfo parsedCountInfo, - int indirectionLevel, - Dictionary useSiteAttributes, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed, - out MarshallingInfo marshallingInfo) - { - // Check for an implicit SafeHandle conversion. - // The SafeHandle type might not be defined if we're using one of the test CoreLib implementations used for NativeAOT. - ITypeSymbol? safeHandleType = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle); - if (safeHandleType is not null) - { - CodeAnalysis.Operations.CommonConversion conversion = _compilation.ClassifyCommonConversion(type, safeHandleType); - if (conversion.Exists - && conversion.IsImplicit - && (conversion.IsReference || conversion.IsIdentity)) - { - bool hasAccessibleDefaultConstructor = false; - if (type is INamedTypeSymbol named && !named.IsAbstract && named.InstanceConstructors.Length > 0) - { - foreach (IMethodSymbol ctor in named.InstanceConstructors) - { - if (ctor.Parameters.Length == 0) - { - hasAccessibleDefaultConstructor = _compilation.IsSymbolAccessibleWithin(ctor, _contextSymbol.ContainingType); - break; - } - } - } - marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor, type.IsAbstract); - return true; - } - } - - if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - MarshallingInfo elementMarshallingInfo = GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsed); - marshallingInfo = CreateArrayMarshallingInfo(type, elementType, parsedCountInfo, elementMarshallingInfo); - return true; - } - - // No marshalling info was computed, but a character encoding was provided. - // If the type is a character or string then pass on these details. - if (type.SpecialType == SpecialType.System_Char && _defaultInfo.CharEncoding != CharEncoding.Undefined) - { - marshallingInfo = new MarshallingInfoStringSupport(_defaultInfo.CharEncoding); - return true; - } - - if (type.SpecialType == SpecialType.System_String && _defaultInfo.CharEncoding != CharEncoding.Undefined) - { - if (_defaultInfo.CharEncoding == CharEncoding.Custom) - { - if (_defaultInfo.StringMarshallingCustomType is not null) - { - AttributeData attrData = _contextSymbol is IMethodSymbol - ? _contextSymbol.GetAttributes().FirstOrDefault(a => a.AttributeClass.ToDisplayString() == TypeNames.LibraryImportAttribute) - : default; - marshallingInfo = CreateNativeMarshallingInfo(type, _defaultInfo.StringMarshallingCustomType, attrData, true, indirectionLevel, parsedCountInfo, useSiteAttributes, inspectedElements, ref maxIndirectionDepthUsed); - return true; - } - } - else - { - marshallingInfo = _defaultInfo.CharEncoding switch - { - CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller), - CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller), - _ => throw new InvalidOperationException() - }; - - return true; - } - - marshallingInfo = new MarshallingInfoStringSupport(_defaultInfo.CharEncoding); - return true; - } - - - if (type.SpecialType == SpecialType.System_Boolean) - { - // We explicitly don't support marshalling bool without any marshalling info - // as treating bool as a non-normalized 1-byte value is generally not a good default. - // Additionally, that default is different than the runtime marshalling, so by explicitly - // blocking bool marshalling without additional info, we make it a little easier - // to transition by explicitly notifying people of changing behavior. - marshallingInfo = NoMarshallingInfo.Instance; - return false; - } - - if (type is INamedTypeSymbol { IsUnmanagedType: true } unmanagedType - && unmanagedType.IsConsideredBlittable()) - { - marshallingInfo = GetBlittableMarshallingInfo(type); - return true; - } - - marshallingInfo = NoMarshallingInfo.Instance; - return false; - } - - private MarshallingInfo CreateArrayMarshallingInfo( - ITypeSymbol managedType, - ITypeSymbol elementType, - CountInfo countInfo, - MarshallingInfo elementMarshallingInfo) - { - ITypeSymbol typeArgumentToInsert = elementType; - INamedTypeSymbol? arrayMarshaller; - if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_PointerArrayMarshaller_Metadata); - typeArgumentToInsert = pointedAt; - } - else - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_ArrayMarshaller_Metadata); - } - - if (arrayMarshaller is null) - { - // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. - return new MissingSupportCollectionMarshallingInfo(countInfo, elementMarshallingInfo); - } - - if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(arrayMarshaller) - && ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(arrayMarshaller)) - { - arrayMarshaller = arrayMarshaller.Construct( - typeArgumentToInsert, - arrayMarshaller.TypeArguments.Last()); - - Func getMarshallingInfoForElement = (ITypeSymbol elementType) => elementMarshallingInfo; - if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(arrayMarshaller, managedType, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? marshallers)) - { - return new NativeLinearCollectionMarshallingInfo( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), - marshallers.Value, - countInfo, - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller.TypeParameters.Last())); - } - } - - Debug.WriteLine("Default marshallers for arrays should be a valid shape."); - return NoMarshallingInfo.Instance; - } - - private MarshallingInfo CreateStringMarshallingInfo( - ITypeSymbol type, - UnmanagedType unmanagedType) - { - string? marshallerName = unmanagedType switch - { - UnmanagedType.BStr => TypeNames.BStrStringMarshaller, - UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller, - UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller, - MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller, - _ => null - }; - - if (marshallerName is null) - return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); - - return CreateStringMarshallingInfo(type, marshallerName); - } - - private MarshallingInfo CreateStringMarshallingInfo( - ITypeSymbol type, - string marshallerName) - { - INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName); - if (stringMarshaller is null) - return new MissingSupportMarshallingInfo(); - - if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(stringMarshaller)) - { - if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(stringMarshaller, type, _compilation, out CustomTypeMarshallers? marshallers)) - { - return new NativeMarshallingAttributeInfo( - EntryPointType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(stringMarshaller), - Marshallers: marshallers.Value); - } - } - - return new MissingSupportMarshallingInfo(); - } - - private MarshallingInfo GetBlittableMarshallingInfo(ITypeSymbol type) - { - if (type.TypeKind is TypeKind.Enum or TypeKind.Pointer or TypeKind.FunctionPointer - || type.SpecialType.IsAlwaysBlittable()) - { - // Treat primitive types and enums as having no marshalling info. - // They are supported in configurations where runtime marshalling is enabled. - return NoMarshallingInfo.Instance; - } - else if (_compilation.GetTypeByMetadataName(TypeNames.System_Runtime_CompilerServices_DisableRuntimeMarshallingAttribute) is null) - { - // If runtime marshalling cannot be disabled, then treat this as a "missing support" scenario so we can gracefully fall back to using the forwarder downlevel. - return new MissingSupportMarshallingInfo(); - } - else - { - return new UnmanagedBlittableMarshallingInfo(type.IsStrictlyBlittable()); - } - } - - private bool TryGetAttributeIndirectionLevel(AttributeData attrData, out int indirectionLevel) - { - if (SymbolEqualityComparer.Default.Equals(attrData.AttributeClass, _marshalAsAttribute)) - { - indirectionLevel = 0; - return true; - } - - if (!SymbolEqualityComparer.Default.Equals(attrData.AttributeClass, _marshalUsingAttribute)) - { - indirectionLevel = 0; - return false; - } - - foreach (KeyValuePair arg in attrData.NamedArguments) - { - if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.ElementIndirectionDepth) - { - indirectionLevel = (int)arg.Value.Value!; - return true; - } - } - indirectionLevel = 0; - return true; - } - - private sealed class CyclicalCountElementInfoException : Exception - { - public CyclicalCountElementInfoException(ImmutableHashSet elementsInCycle, string startOfCycle) - { - ElementsInCycle = elementsInCycle; - StartOfCycle = startOfCycle; - } - - public ImmutableHashSet ElementsInCycle { get; } - - public string StartOfCycle { get; } - } - } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs new file mode 100644 index 00000000000..410166893a5 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs @@ -0,0 +1,367 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Information from a marshalling attribute that can only be provided if the attribute is at the usage site. + /// + /// The indirection depth that the info applies to. + /// Any collection count information provided. + /// The original attribute data. + public sealed record UseSiteAttributeData(int IndirectionDepth, CountInfo CountInfo, AttributeData AttributeData); + + /// + /// A callback to get the marshalling info for a given type at the provided indirection depth with the provided attributes at its usage site. + /// + /// The managed type to get marshalling info for + /// The attributes at the use site + /// The target indirection level + /// Marshalling info for provided information. + public delegate MarshallingInfo GetMarshallingInfoCallback(ITypeSymbol type, UseSiteAttributeProvider useSiteAttributes, int indirectionDepth); + + /// + /// A parser for an attribute used at the marshalling site, such as a parameter or return value attribute. + /// + public interface IUseSiteAttributeParser + { + /// + /// Whether or not the parser can parse an attribute of the provided type. + /// + /// The attribute type + /// true if the parser can parse an attribute of the provided type; otherwise false + bool CanParseAttributeType(INamedTypeSymbol attributeType); + + /// + /// Parse the use-site information out of the provided attribute. + /// + /// The attribute data to parse. + /// The provider for information about other elements. This is used to retrieve information about other parameters that might be referenced by any count information. + /// A callback to provide to the when retrieving additional information. + /// The information about the attribute at the use site. + UseSiteAttributeData ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback); + } + + /// + /// A parser for an attribute that provides information about which marshaller to use. + /// + public interface IMarshallingInfoAttributeParser + { + /// + /// Whether or not the parser can parse an attribute of the provided type. + /// + /// The attribute type + /// true if the parser can parse an attribute of the provided type; otherwise false + bool CanParseAttributeType(INamedTypeSymbol attributeType); + + /// + /// Parse the attribute into marshalling information + /// + /// The attribute to parse + /// The managed type + /// The current indirection depth + /// Attributes provided at the usage site, such as for count information + /// A callback to get marshalling info for nested elements, in the case of a collection of collections. + /// Marshalling information parsed from the attribute, or null if no information could be parsed from the attribute + MarshallingInfo? ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback); + } + + /// + /// A provider of marshalling info based only on the managed type any any previously parsed use-site attribute information + /// + public interface ITypeBasedMarshallingInfoProvider + { + /// + /// Whether or not the provider can provide marshalling info for the given managed type. + /// + /// The managed type + /// true if the provider can provide info for the provided type; otherwise false + bool CanProvideMarshallingInfoForType(ITypeSymbol type); + /// + /// Get marshalling info for the provided type at the given indirection level. + /// + /// The managed type + /// The current indirection depth + /// Attributes provided at the usage site, such as for count information + /// A callback to get marshalling info for nested elements, in the case of a collection of collections. + /// Marshalling information for the provided type + MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback); + } + + /// + /// A provider of TypePositionInfo instances for other elements in the current signature, such as parameters or return values. + /// + public interface IElementInfoProvider + { + /// + /// Get the name for a given index. + /// + /// The index + /// The name associated with the provided index, or if the index does not correspond to an element. + string FindNameForParamIndex(int paramIndex); + /// + /// Get a instance for the given element name. + /// + /// The attribute to report diagnostics on. + /// The element name to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The to pass to the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + bool TryGetInfoForElementName(AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info); + /// + /// Get a instance for the given element index. + /// + /// The attribute to report diagnostics on. + /// The element index to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The to pass to the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + bool TryGetInfoForParamIndex(AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info); + } + + /// + /// Convenience extension methods for . + /// + public static class ElementInfoProviderExtensions + { + /// + /// Get a instance for the given element name. + /// + /// The attribute to report diagnostics on. + /// The element name to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + public static bool TryGetInfoForElementName(this IElementInfoProvider provider, AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, out TypePositionInfo info) + { + return provider.TryGetInfoForElementName(attrData, elementName, marshallingInfoCallback, provider, out info); + } + + /// + /// Get a instance for the given element index. + /// + /// The attribute to report diagnostics on. + /// The element index to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + public static bool TryGetInfoForParamIndex(this IElementInfoProvider provider, AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, out TypePositionInfo info) + { + return provider.TryGetInfoForParamIndex(attrData, paramIndex, marshallingInfoCallback, provider, out info); + } + } + + /// + /// Overall parser for all marshalling info. + /// + /// + /// This type combines the provided parsers to enable parsing marshalling information from a single call. + /// For a given managed type and use site attributes, it will parse marshalling information in the following order: + /// + /// Parse attributes provided at the usage site for use-site-only information. + /// Parse attributes provided at the usage site for marshalling information. + /// If no marshalling information has been found yet, parse attributes provided at the definition of the managed type. + /// If no marshalling information has been found yet, generate marshalling information for the managed type itself. + /// + /// + public sealed class MarshallingInfoParser + { + private readonly IGeneratorDiagnostics _diagnostics; + private readonly IElementInfoProvider _elementInfoProvider; + private readonly ImmutableArray _useSiteMarshallingAttributeParsers; + private readonly ImmutableArray _marshallingAttributeParsers; + private readonly ImmutableArray _typeBasedMarshallingInfoProviders; + + /// + /// Construct a new . + /// + /// The diagnostics sink to report all diagnostics to. + /// An to retrieve information about other elements than the current element when parsing. + /// Parsers for retrieving use-site-only information from attributes. + /// Parsers for retrieving marshalling information from attributes and the managed type. + /// Parsers for retrieving marshalling information from the managed type only. + public MarshallingInfoParser( + IGeneratorDiagnostics diagnostics, + IElementInfoProvider elementInfoProvider, + ImmutableArray useSiteMarshallingAttributeParsers, + ImmutableArray marshallingAttributeParsers, + ImmutableArray typeBasedMarshallingInfoProviders) + { + _diagnostics = diagnostics; + // Always support cycle detection. Otherwise we can get stack-overflows, which does not provide a good dev experience for any customer scenario. + _elementInfoProvider = new CycleDetectingElementInfoProvider(elementInfoProvider, diagnostics); + _useSiteMarshallingAttributeParsers = useSiteMarshallingAttributeParsers; + _marshallingAttributeParsers = marshallingAttributeParsers; + _typeBasedMarshallingInfoProviders = typeBasedMarshallingInfoProviders; + } + + /// + /// Parse the marshalling info for the provided managed type and attributes at the usage site. + /// + /// The managed type + /// All attributes specified at the usage site + /// The parsed marshalling information + public MarshallingInfo ParseMarshallingInfo( + ITypeSymbol managedType, + IEnumerable useSiteAttributes) + { + UseSiteAttributeProvider useSiteAttributeProvider = new UseSiteAttributeProvider(_useSiteMarshallingAttributeParsers, useSiteAttributes, _elementInfoProvider, _diagnostics, GetMarshallingInfo); + + MarshallingInfo info = GetMarshallingInfo( + managedType, + useSiteAttributeProvider, + indirectionDepth: 0); + + useSiteAttributeProvider.OnAttributeUsageFinished(); + return info; + } + + private MarshallingInfo GetMarshallingInfo( + ITypeSymbol type, + UseSiteAttributeProvider useSiteAttributes, + int indirectionDepth) + { + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteAttribute)) + { + if (GetMarshallingInfoForAttribute(useSiteAttribute.AttributeData, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo) + { + return marshallingInfo; + } + } + + // If we aren't overriding the marshalling at usage time, + // then fall back to the information on the element type itself. + foreach (AttributeData typeAttribute in type.GetAttributes()) + { + if (GetMarshallingInfoForAttribute(typeAttribute, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo) + { + return marshallingInfo; + } + } + + // If the type doesn't have custom attributes that dictate marshalling, + // then consider the type itself. + return GetMarshallingInfoForType(type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) ?? NoMarshallingInfo.Instance; + } + + private MarshallingInfo? GetMarshallingInfoForAttribute(AttributeData attribute, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + foreach (var parser in _marshallingAttributeParsers) + { + // Automatically ignore invalid attributes. + // The compiler will already error on them. + if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass)) + { + return parser.ParseAttribute(attribute, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + } + } + return null; + } + + private MarshallingInfo? GetMarshallingInfoForType(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + foreach (var parser in _typeBasedMarshallingInfoProviders) + { + if (parser.CanProvideMarshallingInfoForType(type)) + { + return parser.GetMarshallingInfo(type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + } + } + return null; + } + } + + /// + /// Wraps another with support to detect infinite cycles in marshalling info (i.e. count information that refers to other elements that refer to the original element). + /// + internal sealed class CycleDetectingElementInfoProvider : IElementInfoProvider + { + private ImmutableHashSet _activeInspectingElements = ImmutableHashSet.Empty; + private readonly IElementInfoProvider _innerProvider; + private readonly IGeneratorDiagnostics _diagnostics; + + public CycleDetectingElementInfoProvider(IElementInfoProvider innerProvider, IGeneratorDiagnostics diagnostics) + { + _innerProvider = innerProvider; + _diagnostics = diagnostics; + } + + public string FindNameForParamIndex(int paramIndex) => _innerProvider.FindNameForParamIndex(paramIndex); + public bool TryGetInfoForElementName(AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, [NotNullWhen(true)] out TypePositionInfo? info) + { + ImmutableHashSet inspectedElements = _activeInspectingElements; + if (inspectedElements.Contains(elementName)) + { + throw new CyclicalElementInfoException(inspectedElements, elementName); + } + try + { + _activeInspectingElements = inspectedElements.Add(elementName); + return _innerProvider.TryGetInfoForElementName(attrData, elementName, marshallingInfoCallback, rootProvider, out info); + } + // Specifically catch the exception when we're trying to inspect the element that started the cycle. + // This ensures that we've unwound the whole cycle so when we return, there will be no cycles in the count info. + catch (CyclicalElementInfoException ex) when (ex.StartOfCycle == elementName) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CyclicalCountInfo), elementName); + // Create a dummy value for the invalid marshalling. We're already in an error state, so try to not report extraneous diagnostics. + info = new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance); + return true; + } + finally + { + _activeInspectingElements = inspectedElements; + } + } + + public bool TryGetInfoForParamIndex(AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, [NotNullWhen(true)] out TypePositionInfo? info) + { + ImmutableHashSet inspectedElements = _activeInspectingElements; + string paramName = _innerProvider.FindNameForParamIndex(paramIndex); + if (paramName is not null && inspectedElements.Contains(paramName)) + { + throw new CyclicalElementInfoException(inspectedElements, paramName); + } + + try + { + _activeInspectingElements = inspectedElements.Add(paramName); + return _innerProvider.TryGetInfoForParamIndex(attrData, paramIndex, marshallingInfoCallback, rootProvider, out info); + } + // Specifically catch the exception when we're trying to inspect the element that started the cycle. + // This ensures that we've unwound the whole cycle so when we return, there will be no cycles in the count info. + catch (CyclicalElementInfoException ex) when (ex.StartOfCycle == paramName) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CyclicalCountInfo), paramName); + // Create a dummy value for the invalid marshalling. We're already in an error state, so try to not report extraneous diagnostics. + info = new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance); + return true; + } + finally + { + _activeInspectingElements = inspectedElements; + } + } + + private sealed class CyclicalElementInfoException : Exception + { + public CyclicalElementInfoException(ImmutableHashSet elementsInCycle, string startOfCycle) + { + ElementsInCycle = elementsInCycle; + StartOfCycle = startOfCycle; + } + + public ImmutableHashSet ElementsInCycle { get; } + + public string StartOfCycle { get; } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs new file mode 100644 index 00000000000..e00eef6d2c0 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + public sealed class MethodSignatureElementInfoProvider : IElementInfoProvider + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _generatorDiagnostics; + private readonly IMethodSymbol _method; + private readonly ImmutableArray _useSiteAttributeParsers; + + public MethodSignatureElementInfoProvider(Compilation compilation, IGeneratorDiagnostics generatorDiagnostics, IMethodSymbol method, ImmutableArray useSiteAttributeParsers) + { + _compilation = compilation; + _generatorDiagnostics = generatorDiagnostics; + _method = method; + _useSiteAttributeParsers = useSiteAttributeParsers; + } + + public string FindNameForParamIndex(int paramIndex) => paramIndex >= _method.Parameters.Length ? string.Empty : _method.Parameters[paramIndex].Name; + + public bool TryGetInfoForElementName(AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info) + { + if (elementName == CountElementCountInfo.ReturnValueElementName) + { + info = new TypePositionInfo( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(_method.ReturnType), + marshallingInfoCallback(_method.ReturnType, new UseSiteAttributeProvider(_useSiteAttributeParsers, _method.GetReturnTypeAttributes(), rootProvider, _generatorDiagnostics, marshallingInfoCallback), 0)) with + { + ManagedIndex = TypePositionInfo.ReturnIndex + }; + return true; + } + + for (int i = 0; i < _method.Parameters.Length; i++) + { + IParameterSymbol param = _method.Parameters[i]; + if (param.Name == elementName) + { + info = TypePositionInfo.CreateForParameter( + param, + marshallingInfoCallback(param.Type, new UseSiteAttributeProvider(_useSiteAttributeParsers, param.GetAttributes(), rootProvider, _generatorDiagnostics, marshallingInfoCallback), 0), _compilation) with + { + ManagedIndex = i + }; + return true; + } + } + info = null; + return false; + } + + public bool TryGetInfoForParamIndex(AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info) + { + if (paramIndex >= _method.Parameters.Length) + { + info = null; + return false; + } + IParameterSymbol param = _method.Parameters[paramIndex]; + + info = TypePositionInfo.CreateForParameter( + param, + marshallingInfoCallback(param.Type, new UseSiteAttributeProvider(_useSiteAttributeParsers, param.GetAttributes(), rootProvider, _generatorDiagnostics, marshallingInfoCallback), 0), _compilation) with + { + ManagedIndex = paramIndex + }; + return true; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs new file mode 100644 index 00000000000..c5111602914 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + public sealed class NativeMarshallingAttributeParser : IMarshallingInfoAttributeParser + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + + public NativeMarshallingAttributeParser(Compilation compilation, IGeneratorDiagnostics diagnostics) + { + _compilation = compilation; + _diagnostics = diagnostics; + } + + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.NativeMarshallingAttribute; + + public MarshallingInfo? ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + Debug.Assert(attributeData.AttributeClass!.ToDisplayString() == TypeNames.NativeMarshallingAttribute); + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out var useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + + if (attributeData.ConstructorArguments[0].Value is not INamedTypeSymbol entryPointType) + { + return NoMarshallingInfo.Instance; + } + + return CustomMarshallingInfoHelper.CreateNativeMarshallingInfo( + type, + entryPointType, + attributeData, + useSiteAttributes, + marshallingInfoCallback, + indirectionDepth, + countInfo, + _diagnostics, + _compilation + ); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs new file mode 100644 index 00000000000..0309793096e --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// The type of the element is a SafeHandle-derived type with no marshalling attributes. + /// + public sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor, bool IsAbstract) : MarshallingInfo; + + /// + /// This class supports generating marshalling info for SafeHandle-derived types. + /// + public sealed class SafeHandleMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + private readonly ITypeSymbol _containingScope; + + public SafeHandleMarshallingInfoProvider(Compilation compilation, ITypeSymbol containingScope) + { + _compilation = compilation; + _containingScope = containingScope; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) + { + // Check for an implicit SafeHandle conversion. + // The SafeHandle type might not be defined if we're using one of the test CoreLib implementations used for NativeAOT. + ITypeSymbol? safeHandleType = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle); + if (safeHandleType is not null) + { + CodeAnalysis.Operations.CommonConversion conversion = _compilation.ClassifyCommonConversion(type, safeHandleType); + if (conversion.Exists + && conversion.IsImplicit + && (conversion.IsReference || conversion.IsIdentity)) + { + return true; + } + } + return false; + } + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + bool hasAccessibleDefaultConstructor = false; + if (type is INamedTypeSymbol named && !named.IsAbstract && named.InstanceConstructors.Length > 0) + { + foreach (IMethodSymbol ctor in named.InstanceConstructors) + { + if (ctor.Parameters.Length == 0) + { + hasAccessibleDefaultConstructor = _compilation.IsSymbolAccessibleWithin(ctor, _containingScope); + break; + } + } + } + return new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor, type.IsAbstract); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs index da94ec9a541..b9df1831512 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs @@ -56,9 +56,10 @@ public IEnumerable StubParameters InteropAttributeData interopAttributeData, StubEnvironment env, IGeneratorDiagnostics diagnostics, + AttributeData signatureWideMarshallingAttributeData, Assembly generatorInfoAssembly) { - ImmutableArray typeInfos = GenerateTypeInformation(method, interopAttributeData, diagnostics, env); + ImmutableArray typeInfos = GenerateTypeInformation(method, interopAttributeData, diagnostics, env, signatureWideMarshallingAttributeData); ImmutableArray.Builder additionalAttrs = ImmutableArray.CreateBuilder(); @@ -99,7 +100,12 @@ public IEnumerable StubParameters }; } - private static ImmutableArray GenerateTypeInformation(IMethodSymbol method, InteropAttributeData interopAttributeData, IGeneratorDiagnostics diagnostics, StubEnvironment env) + private static ImmutableArray GenerateTypeInformation( + IMethodSymbol method, + InteropAttributeData interopAttributeData, + IGeneratorDiagnostics diagnostics, + StubEnvironment env, + AttributeData signatureWideMarshallingAttributeData) { // Compute the current default string encoding value. CharEncoding defaultEncoding = CharEncoding.Undefined; @@ -120,14 +126,32 @@ private static ImmutableArray GenerateTypeInformation(IMethodS var defaultInfo = new DefaultMarshallingInfo(defaultEncoding, interopAttributeData.StringMarshallingCustomType); - var marshallingAttributeParser = new MarshallingAttributeInfoParser(env.Compilation, diagnostics, defaultInfo, method); + var useSiteAttributeParsers = ImmutableArray.Create( + new MarshalAsAttributeParser(env.Compilation, diagnostics, defaultInfo), + new MarshalUsingAttributeParser(env.Compilation, diagnostics)); + + var marshallingInfoParser = new MarshallingInfoParser( + diagnostics, + new MethodSignatureElementInfoProvider(env.Compilation, diagnostics, method, useSiteAttributeParsers), + useSiteAttributeParsers, + ImmutableArray.Create( + new MarshalAsAttributeParser(env.Compilation, diagnostics, defaultInfo), + new MarshalUsingAttributeParser(env.Compilation, diagnostics), + new NativeMarshallingAttributeParser(env.Compilation, diagnostics)), + ImmutableArray.Create( + new SafeHandleMarshallingInfoProvider(env.Compilation, method.ContainingType), + new ArrayMarshallingInfoProvider(env.Compilation), + new CharMarshallingInfoProvider(defaultInfo), + new StringMarshallingInfoProvider(env.Compilation, diagnostics, signatureWideMarshallingAttributeData, defaultInfo), + new BooleanMarshallingInfoProvider(), + new BlittableTypeMarshallingInfoProvider(env.Compilation))); // Determine parameter and return types ImmutableArray.Builder typeInfos = ImmutableArray.CreateBuilder(); for (int i = 0; i < method.Parameters.Length; i++) { IParameterSymbol param = method.Parameters[i]; - MarshallingInfo marshallingInfo = marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); + MarshallingInfo marshallingInfo = marshallingInfoParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation); typeInfo = typeInfo with { @@ -137,7 +161,7 @@ private static ImmutableArray GenerateTypeInformation(IMethodS typeInfos.Add(typeInfo); } - TypePositionInfo retTypeInfo = new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); + TypePositionInfo retTypeInfo = new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), marshallingInfoParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); retTypeInfo = retTypeInfo with { ManagedIndex = TypePositionInfo.ReturnIndex, diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs new file mode 100644 index 00000000000..e8c9aee264d --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + + /// + /// This class supports generating marshalling info for the type. + /// This includes support for the System.Runtime.InteropServices.StringMarshalling enum. + /// + public sealed class StringMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + private readonly AttributeData _stringMarshallingCustomAttribute; + private readonly DefaultMarshallingInfo _defaultMarshallingInfo; + + public StringMarshallingInfoProvider(Compilation compilation, IGeneratorDiagnostics diagnostics, AttributeData stringMarshallingCustomAttribute, DefaultMarshallingInfo defaultMarshallingInfo) + { + _compilation = compilation; + _diagnostics = diagnostics; + _stringMarshallingCustomAttribute = stringMarshallingCustomAttribute; + _defaultMarshallingInfo = defaultMarshallingInfo; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type.SpecialType == SpecialType.System_String; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + if (_defaultMarshallingInfo.CharEncoding == CharEncoding.Undefined) + { + return NoMarshallingInfo.Instance; + } + else if (_defaultMarshallingInfo.CharEncoding == CharEncoding.Custom) + { + if (_defaultMarshallingInfo.StringMarshallingCustomType is not null) + { + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out var useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + return CustomMarshallingInfoHelper.CreateNativeMarshallingInfo( + type, + _defaultMarshallingInfo.StringMarshallingCustomType, + _stringMarshallingCustomAttribute, + useSiteAttributes, + marshallingInfoCallback, + indirectionDepth, + countInfo, + _diagnostics, + _compilation); + } + } + else + { + // No marshalling info was computed, but a character encoding was provided. + return _defaultMarshallingInfo.CharEncoding switch + { + CharEncoding.Utf16 => CreateStringMarshallingInfo(_compilation, type, TypeNames.Utf16StringMarshaller), + CharEncoding.Utf8 => CreateStringMarshallingInfo(_compilation, type, TypeNames.Utf8StringMarshaller), + _ => throw new InvalidOperationException() + }; + } + + return new MarshallingInfoStringSupport(_defaultMarshallingInfo.CharEncoding); + } + + public static MarshallingInfo CreateStringMarshallingInfo( + Compilation compilation, + ITypeSymbol type, + string marshallerName) + { + INamedTypeSymbol? stringMarshaller = compilation.GetTypeByMetadataName(marshallerName); + if (stringMarshaller is null) + return new MissingSupportMarshallingInfo(); + + if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(stringMarshaller)) + { + if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(stringMarshaller, type, compilation, out CustomTypeMarshallers? marshallers)) + { + return new NativeMarshallingAttributeInfo( + EntryPointType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(stringMarshaller), + Marshallers: marshallers.Value); + } + } + + return new MissingSupportMarshallingInfo(); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs new file mode 100644 index 00000000000..fa19c1991e2 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Parses information for all use-site attributes and tracks their usage. + /// + public sealed class UseSiteAttributeProvider + { + private readonly ImmutableDictionary _useSiteAttributesByIndirectionDepth; + private readonly int _maxIndirectionLevelDataProvided; + private readonly IGeneratorDiagnostics _diagnostics; + private int _maxIndirectionLevelUsed; + + /// + /// Construct a new for a given usage site. + /// + /// The parsers for the attributes at the given usage site. + /// The attributes at the usage site. + /// The provider for additional element information, used by the attribute parsers. + /// Diagnostics sink for any invalid configurations. + /// A callback to get marshalling information for other elements. Used by . + internal UseSiteAttributeProvider( + ImmutableArray useSiteAttributeParsers, + IEnumerable useSiteAttributes, + IElementInfoProvider elementInfoProvider, + IGeneratorDiagnostics diagnostics, + GetMarshallingInfoCallback getMarshallingInfoCallback) + { + ImmutableDictionary.Builder useSiteAttributesByIndirectionDepth = ImmutableDictionary.CreateBuilder(); + _maxIndirectionLevelDataProvided = 0; + foreach (AttributeData attribute in useSiteAttributes) + { + UseSiteAttributeData? useSiteAttributeData = GetUseSiteInfoForAttribute(attribute); + if (useSiteAttributeData is not null) + { + int indirectionDepth = useSiteAttributeData.IndirectionDepth; + if (useSiteAttributesByIndirectionDepth.ContainsKey(indirectionDepth)) + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attribute, nameof(SR.DuplicateMarshallingInfo), indirectionDepth.ToString()); + } + else + { + useSiteAttributesByIndirectionDepth.Add(indirectionDepth, useSiteAttributeData); + _maxIndirectionLevelDataProvided = Math.Max(_maxIndirectionLevelDataProvided, indirectionDepth); + } + } + } + _useSiteAttributesByIndirectionDepth = useSiteAttributesByIndirectionDepth.ToImmutable(); + _diagnostics = diagnostics; + + UseSiteAttributeData? GetUseSiteInfoForAttribute(AttributeData attribute) + { + foreach (var parser in useSiteAttributeParsers) + { + // Automatically ignore invalid attributes. + // The compiler will already error on them. + if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass)) + { + return parser.ParseAttribute(attribute, elementInfoProvider, getMarshallingInfoCallback); + } + } + return null; + } + } + + /// + /// Get the provided for a given , if it exists. + /// + /// The indirection depth to retrieve info for. + /// The use site information, if it exists. + /// true if an attribute was provided for the given indirection depth. + public bool TryGetUseSiteAttributeInfo(int indirectionDepth, out UseSiteAttributeData useSiteInfo) + { + _maxIndirectionLevelUsed = Math.Max(indirectionDepth, _maxIndirectionLevelUsed); + return _useSiteAttributesByIndirectionDepth.TryGetValue(indirectionDepth, out useSiteInfo); + } + + /// + /// Call when no more of the use-site attribute information will be used. + /// Records any information or diagnostics about unused marshalling information. + /// + internal void OnAttributeUsageFinished() + { + if (_maxIndirectionLevelUsed < _maxIndirectionLevelDataProvided) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo( + _useSiteAttributesByIndirectionDepth[_maxIndirectionLevelDataProvided].AttributeData, + nameof(SR.ExtraneousMarshallingInfo), + _maxIndirectionLevelDataProvided.ToString(), + _maxIndirectionLevelUsed.ToString()); + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs index 5a77bb12d7b..4d10887e344 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs @@ -120,8 +120,8 @@ public static IEnumerable CodeSnippetsToCompile() // Generic collection marshaller has different arity than collection. yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.GenericCollectionMarshallingArityMismatch, 2, 0 }; - yield return new object[] { ID(), CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 2, 0 }; - yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshallingDuplicateElementIndirectionDepth, 2, 0 }; + yield return new object[] { ID(), CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 1, 0 }; + yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshallingDuplicateElementIndirectionDepth, 1, 0 }; yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshallingUnusedElementIndirectionDepth, 1, 0 }; yield return new object[] { ID(), CodeSnippets.RecursiveCountElementNameOnReturnValue, 2, 0 }; yield return new object[] { ID(), CodeSnippets.RecursiveCountElementNameOnParameter, 2, 0 }; -- GitLab