diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..58ed34812e7df0bbae1d5f4828c3d3912ad5ce19 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/FallbackJSMarshallingInfoProvider.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop.JavaScript +{ + /// + /// Always returns a JSMissingMarshallingInfo. + /// + internal sealed class FallbackJSMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => true; + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + return new JSMissingMarshallingInfo(JSTypeInfo.CreateJSTypeInfoForTypeSymbol(type)); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs index 1318358bea46325e15f87fdd880b87481305fd21..0dba53f9bf0b63cde244208d536fb9f3b2114d77 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs @@ -40,9 +40,9 @@ internal sealed class JSExportCodeGenerator : JSCodeGenerator } // validate task + span mix - if (_marshallers.ManagedReturnMarshaller.TypeInfo.ManagedType is JSTaskTypeInfo) + if (_marshallers.ManagedReturnMarshaller.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSTaskTypeInfo)) { - BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.ManagedType is JSSpanTypeInfo); + BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSSpanTypeInfo)); if (spanArg != default) { marshallingNotSupportedCallback(spanArg.TypeInfo, new MarshallingNotSupportedException(spanArg.TypeInfo, _context) diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs index f908872318b5f0cdac19bb036c02c617c47a6a9f..7c000011bf046212bc0726eb9d80aa9c34c6a3f6 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSGeneratorFactory.cs @@ -7,6 +7,7 @@ using Microsoft.CodeAnalysis; using System.Collections.Generic; using System.Diagnostics; +using System.Data; namespace Microsoft.Interop.JavaScript { @@ -34,51 +35,51 @@ Exception fail(string failReason) } bool isToJs = info.ManagedIndex != TypePositionInfo.ReturnIndex ^ context is JSExportCodeContext; - switch (info) + switch (jsMarshalingInfo) { // invalid - case { ManagedType: JSInvalidTypeInfo }: + case { TypeInfo: JSInvalidTypeInfo }: throw new MarshallingNotSupportedException(info, context); // void - case { ManagedType: SpecialTypeInfo sd } when sd.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.Discard: - case { ManagedType: SpecialTypeInfo sv } when sv.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.Void: - case { ManagedType: SpecialTypeInfo sn } when sn.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.None: - case { ManagedType: SpecialTypeInfo sm } when sm.SpecialType == SpecialType.System_Void && jsMarshalingInfo.JSType == JSTypeFlags.Missing: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.Discard }: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.Void }: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.None }: + case { TypeInfo: JSSimpleTypeInfo(KnownManagedType.Void), JSType: JSTypeFlags.Missing }: return new VoidGenerator(jsMarshalingInfo.JSType == JSTypeFlags.Void ? MarshalerType.Void : MarshalerType.Discard); // discard no void - case { } when jsMarshalingInfo.JSType == JSTypeFlags.Discard: + case { JSType: JSTypeFlags.Discard }: throw fail(SR.DiscardOnlyVoid); // primitive - case { ManagedType: JSSimpleTypeInfo simple }: + case { TypeInfo: JSSimpleTypeInfo simple }: return Create(info, isToJs, simple.KnownType, Array.Empty(), jsMarshalingInfo.JSType, Array.Empty(), fail); // nullable - case { ManagedType: JSNullableTypeInfo nullable }: + case { TypeInfo: JSNullableTypeInfo nullable }: return Create(info, isToJs, nullable.KnownType, new[] { nullable.ResultTypeInfo.KnownType }, jsMarshalingInfo.JSType, null, fail); // array - case { ManagedType: JSArrayTypeInfo array }: + case { TypeInfo: JSArrayTypeInfo array }: return Create(info, isToJs, array.KnownType, new[] { array.ElementTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // array segment - case { ManagedType: JSArraySegmentTypeInfo segment }: + case { TypeInfo: JSArraySegmentTypeInfo segment }: return Create(info, isToJs, segment.KnownType, new[] { segment.ElementTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // span - case { ManagedType: JSSpanTypeInfo span }: + case { TypeInfo: JSSpanTypeInfo span }: return Create(info, isToJs, span.KnownType, new[] { span.ElementTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // task - case { ManagedType: JSTaskTypeInfo task } when task.ResultTypeInfo is JSSimpleTypeInfo taskRes && taskRes.FullTypeName == "void": + case { TypeInfo: JSTaskTypeInfo(JSSimpleTypeInfo(KnownManagedType.Void)) task }: return Create(info, isToJs, task.KnownType, Array.Empty(), jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); - case { ManagedType: JSTaskTypeInfo task }: + case { TypeInfo: JSTaskTypeInfo task }: return Create(info, isToJs, task.KnownType, new[] { task.ResultTypeInfo.KnownType }, jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); // action + function - case { ManagedType: JSFunctionTypeInfo function }: + case { TypeInfo: JSFunctionTypeInfo function }: return Create(info, isToJs, function.KnownType, function.ArgsTypeInfo.Select(a => a.KnownType).ToArray(), jsMarshalingInfo.JSType, jsMarshalingInfo.JSTypeArguments, fail); default: diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs index 7bed0e9df4ea538e3621ad7e0c83552ea24d9f19..9c10b08ef87570b3fc030b5b92e519ecb54e2b80 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs @@ -47,9 +47,9 @@ internal sealed class JSImportCodeGenerator : JSCodeGenerator } // validate task + span mix - if (_marshallers.ManagedReturnMarshaller.TypeInfo.ManagedType is JSTaskTypeInfo) + if (_marshallers.ManagedReturnMarshaller.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSTaskTypeInfo)) { - BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.ManagedType is JSSpanTypeInfo); + BoundGenerator spanArg = _marshallers.AllMarshallers.FirstOrDefault(m => m.TypeInfo.MarshallingAttributeInfo is JSMarshallingInfo(_, JSSpanTypeInfo)); if (spanArg != default) { marshallingNotSupportedCallback(spanArg.TypeInfo, new MarshallingNotSupportedException(spanArg.TypeInfo, _context) diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs index 055e81315e68e4384e4b3a67658bc448b3b073a3..84d3cc2b692fdb0740d67d8907e11c5d0075b9a2 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportStubContext.cs @@ -145,19 +145,22 @@ static bool IsSkipLocalsInitAttribute(AttributeData a) private static (ImmutableArray, IMarshallingGeneratorFactory) GenerateTypeInformation(IMethodSymbol method, GeneratorDiagnostics diagnostics, StubEnvironment env) { - var jsMarshallingAttributeParser = new JSMarshallingAttributeInfoParser(env.Compilation, diagnostics, method); + ImmutableArray useSiteAttributeParsers = ImmutableArray.Create(new JSMarshalAsAttributeParser(env.Compilation)); + var jsMarshallingAttributeParser = new MarshallingInfoParser( + diagnostics, + new MethodSignatureElementInfoProvider(env.Compilation, diagnostics, method, useSiteAttributeParsers), + useSiteAttributeParsers, + ImmutableArray.Create(new JSMarshalAsAttributeParser(env.Compilation)), + ImmutableArray.Create(new FallbackJSMarshallingInfoProvider())); // Determine parameter and return types ImmutableArray.Builder typeInfos = ImmutableArray.CreateBuilder(); for (int i = 0; i < method.Parameters.Length; i++) { IParameterSymbol param = method.Parameters[i]; - MarshallingInfo marshallingInfo = NoMarshallingInfo.Instance; - MarshallingInfo jsMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes(), marshallingInfo); + MarshallingInfo jsMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); - var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation); - typeInfo = JSTypeInfo.CreateForType(typeInfo, param.Type, jsMarshallingInfo, env.Compilation); - typeInfo = typeInfo with + var typeInfo = TypePositionInfo.CreateForParameter(param, jsMarshallingInfo, env.Compilation) with { ManagedIndex = i, NativeIndex = typeInfos.Count, @@ -165,12 +168,9 @@ private static (ImmutableArray, IMarshallingGeneratorFactory) typeInfos.Add(typeInfo); } - MarshallingInfo retMarshallingInfo = NoMarshallingInfo.Instance; - MarshallingInfo retJSMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes(), retMarshallingInfo); + MarshallingInfo retJSMarshallingInfo = jsMarshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes()); - var retTypeInfo = new TypePositionInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), retMarshallingInfo); - retTypeInfo = JSTypeInfo.CreateForType(retTypeInfo, method.ReturnType, retJSMarshallingInfo, env.Compilation); - retTypeInfo = retTypeInfo with + var retTypeInfo = new TypePositionInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), retJSMarshallingInfo) { ManagedIndex = TypePositionInfo.ReturnIndex, NativeIndex = TypePositionInfo.ReturnIndex, diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs index 0eedd0aa873e55a9af85c9c40b6443e270b14694..0f9084a01e8416879189ef14d21197a7bdd8e21c 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSManagedTypeInfo.cs @@ -4,185 +4,192 @@ using System; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Microsoft.Interop.JavaScript { - internal abstract record JSTypeInfo(string FullTypeName, string DiagnosticFormattedName, KnownManagedType KnownType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName) + internal abstract record JSTypeInfo(KnownManagedType KnownType) { - public static ManagedTypeInfo CreateJSTypeInfoForTypeSymbol(ITypeSymbol type) + public static JSTypeInfo CreateJSTypeInfoForTypeSymbol(ITypeSymbol type) { string fullTypeName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - if (fullTypeName == "void") + switch (type) { - return SpecialTypeInfo.Void; - } - string diagnosticFormattedName = type.ToDisplayString(); - return CreateJSTypeInfoForTypeSymbol(fullTypeName, diagnosticFormattedName); - } - - public static ManagedTypeInfo CreateJSTypeInfoForTypeSymbol(string fullTypeName, string diagnosticFormattedName) - { - switch (fullTypeName.Trim()) - { - case "global::System.Void": - case "void": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Void); - case "global::System.Boolean": - case "bool": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Boolean); - case "global::System.Byte": - case "byte": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Byte); - case "global::System.Char": - case "char": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Char); - case "global::System.Int16": - case "short": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Int16); - case "global::System.Int32": - case "int": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Int32); - case "global::System.Int64": - case "long": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Int64); - case "global::System.Single": - case "float": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Single); - case "global::System.Double": - case "double": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Double); - case "global::System.IntPtr": - case "nint": - case "void*": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.IntPtr); - case "global::System.DateTime": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.DateTime); - case "global::System.DateTimeOffset": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.DateTimeOffset); - case "global::System.Exception": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Exception); - case "global::System.Object": - case "object": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.Object); - case "global::System.String": - case "string": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.String); - case "global::System.Runtime.InteropServices.JavaScript.JSObject": - return new JSSimpleTypeInfo(fullTypeName, diagnosticFormattedName, KnownManagedType.JSObject); + case { SpecialType: SpecialType.System_Void }: + return new JSSimpleTypeInfo(KnownManagedType.Void) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.VoidKeyword)) + }; + case { SpecialType: SpecialType.System_Boolean }: + return new JSSimpleTypeInfo(KnownManagedType.Boolean) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.BoolKeyword)) + }; + case { SpecialType: SpecialType.System_Byte }: + return new JSSimpleTypeInfo(KnownManagedType.Byte) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)) + }; + case { SpecialType: SpecialType.System_Char }: + return new JSSimpleTypeInfo(KnownManagedType.Char) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)) + }; + case { SpecialType: SpecialType.System_Int16 }: + return new JSSimpleTypeInfo(KnownManagedType.Int16) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ShortKeyword)) + }; + case { SpecialType: SpecialType.System_Int32 }: + return new JSSimpleTypeInfo(KnownManagedType.Int32) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.IntKeyword)) + }; + case { SpecialType: SpecialType.System_Int64 }: + return new JSSimpleTypeInfo(KnownManagedType.Int64) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.LongKeyword)) + }; + case { SpecialType: SpecialType.System_Single }: + return new JSSimpleTypeInfo(KnownManagedType.Single) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.FloatKeyword)) + }; + case { SpecialType: SpecialType.System_Double }: + return new JSSimpleTypeInfo(KnownManagedType.Double) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.DoubleKeyword)) + }; + case { SpecialType: SpecialType.System_IntPtr }: + case IPointerTypeSymbol { PointedAtType.SpecialType: SpecialType.System_Void }: + return new JSSimpleTypeInfo(KnownManagedType.IntPtr) + { + Syntax = SyntaxFactory.IdentifierName("nint") + }; + case { SpecialType: SpecialType.System_DateTime }: + return new JSSimpleTypeInfo(KnownManagedType.DateTime) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; + case ITypeSymbol when fullTypeName == "global::System.DateTimeOffset": + return new JSSimpleTypeInfo(KnownManagedType.DateTimeOffset) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; + case ITypeSymbol when fullTypeName == "global::System.Exception": + return new JSSimpleTypeInfo(KnownManagedType.Exception) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; + case { SpecialType: SpecialType.System_Object }: + return new JSSimpleTypeInfo(KnownManagedType.Object) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ObjectKeyword)) + }; + case { SpecialType: SpecialType.System_String }: + return new JSSimpleTypeInfo(KnownManagedType.String) + { + Syntax = SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.StringKeyword)) + }; + case ITypeSymbol when fullTypeName == "global::System.Runtime.InteropServices.JavaScript.JSObject": + return new JSSimpleTypeInfo(KnownManagedType.JSObject) + { + Syntax = SyntaxFactory.ParseTypeName(fullTypeName.Trim()) + }; //nullable - case string ftn when ftn.EndsWith("?"): - var ut = fullTypeName.Remove(fullTypeName.Length - 1); - if (CreateJSTypeInfoForTypeSymbol(ut, diagnosticFormattedName) is JSSimpleTypeInfo uti) + case INamedTypeSymbol { ConstructedFrom.SpecialType: SpecialType.System_Nullable_T } nullable: + if (CreateJSTypeInfoForTypeSymbol(nullable.TypeArguments[0]) is JSSimpleTypeInfo uti) { - return new JSNullableTypeInfo(fullTypeName, diagnosticFormattedName, uti); + return new JSNullableTypeInfo(uti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // array - case string ftn when ftn.EndsWith("[]"): - var et = fullTypeName.Remove(fullTypeName.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(et, diagnosticFormattedName) is JSSimpleTypeInfo eti) + case IArrayTypeSymbol { IsSZArray: true, ElementType: ITypeSymbol elementType }: + if (CreateJSTypeInfoForTypeSymbol(elementType) is JSSimpleTypeInfo eti) { - return new JSArrayTypeInfo(fullTypeName, diagnosticFormattedName, eti); + return new JSArrayTypeInfo(eti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // task - case Constants.TaskGlobal: - return new JSTaskTypeInfo(fullTypeName, diagnosticFormattedName, (JSSimpleTypeInfo)CreateJSTypeInfoForTypeSymbol("void", diagnosticFormattedName)); - case string ft when ft.StartsWith(Constants.TaskGlobal): - var rt = fullTypeName.Substring(Constants.TaskGlobal.Length + 1, fullTypeName.Length - Constants.TaskGlobal.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(rt, diagnosticFormattedName) is JSSimpleTypeInfo rti) + case ITypeSymbol when fullTypeName == Constants.TaskGlobal: + return new JSTaskTypeInfo(new JSSimpleTypeInfo(KnownManagedType.Void, SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.VoidKeyword)))); + case INamedTypeSymbol { TypeArguments.Length: 1 } taskType when fullTypeName.StartsWith(Constants.TaskGlobal, StringComparison.Ordinal): + if (CreateJSTypeInfoForTypeSymbol(taskType.TypeArguments[0]) is JSSimpleTypeInfo rti) { - return new JSTaskTypeInfo(fullTypeName, diagnosticFormattedName, rti); + return new JSTaskTypeInfo(rti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // span - case string ft when ft.StartsWith(Constants.SpanGlobal): - var st = fullTypeName.Substring(Constants.SpanGlobal.Length + 1, fullTypeName.Length - Constants.SpanGlobal.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(st, diagnosticFormattedName) is JSSimpleTypeInfo sti) + case INamedTypeSymbol { TypeArguments.Length: 1 } spanType when fullTypeName.StartsWith(Constants.SpanGlobal, StringComparison.Ordinal): + if (CreateJSTypeInfoForTypeSymbol(spanType.TypeArguments[0]) is JSSimpleTypeInfo sti) { - return new JSSpanTypeInfo(fullTypeName, diagnosticFormattedName, sti); + return new JSSpanTypeInfo(sti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // array segment - case string ft when ft.StartsWith(Constants.ArraySegmentGlobal): - var gt = fullTypeName.Substring(Constants.ArraySegmentGlobal.Length + 1, fullTypeName.Length - Constants.ArraySegmentGlobal.Length - 2); - if (CreateJSTypeInfoForTypeSymbol(gt, diagnosticFormattedName) is JSSimpleTypeInfo gti) + case INamedTypeSymbol { TypeArguments.Length: 1 } arraySegmentType when fullTypeName.StartsWith(Constants.ArraySegmentGlobal, StringComparison.Ordinal): + if (CreateJSTypeInfoForTypeSymbol(arraySegmentType.TypeArguments[0]) is JSSimpleTypeInfo gti) { - return new JSArraySegmentTypeInfo(fullTypeName, diagnosticFormattedName, gti); + return new JSArraySegmentTypeInfo(gti); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); // action - case Constants.ActionGlobal: - return new JSFunctionTypeInfo(fullTypeName, diagnosticFormattedName, true, Array.Empty()); - case string ft when ft.StartsWith(Constants.ActionGlobal): - var argNames = fullTypeName.Substring(Constants.ActionGlobal.Length + 1, fullTypeName.Length - Constants.ActionGlobal.Length - 2); - if (!argNames.Contains("<")) - { - var ga = argNames.Split(',') - .Select(argName => CreateJSTypeInfoForTypeSymbol(argName, diagnosticFormattedName) as JSSimpleTypeInfo) - .ToArray(); - if (ga.Any(x => x == null)) - { - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); - } - return new JSFunctionTypeInfo(fullTypeName, diagnosticFormattedName, true, ga); + case ITypeSymbol when fullTypeName == Constants.ActionGlobal: + return new JSFunctionTypeInfo(true, Array.Empty()); + case INamedTypeSymbol actionType when fullTypeName.StartsWith(Constants.ActionGlobal, StringComparison.Ordinal): + var argumentTypes = actionType.TypeArguments + .Select(arg => CreateJSTypeInfoForTypeSymbol(arg) as JSSimpleTypeInfo) + .ToArray(); + if (argumentTypes.Any(x => x is null)) + { + return new JSInvalidTypeInfo(); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSFunctionTypeInfo(true, argumentTypes); // function - case string ft when ft.StartsWith(Constants.FuncGlobal): - var fargNames = fullTypeName.Substring(Constants.FuncGlobal.Length + 1, fullTypeName.Length - Constants.FuncGlobal.Length - 2); - if (!fargNames.Contains("<")) - { - var ga = fargNames.Split(',') - .Select(argName => CreateJSTypeInfoForTypeSymbol(argName, diagnosticFormattedName) as JSSimpleTypeInfo) - .ToArray(); - if (ga.Any(x => x == null)) - { - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); - } - return new JSFunctionTypeInfo(fullTypeName, diagnosticFormattedName, false, ga); + case INamedTypeSymbol funcType when fullTypeName.StartsWith(Constants.FuncGlobal, StringComparison.Ordinal): + var signatureTypes = funcType.TypeArguments + .Select(argName => CreateJSTypeInfoForTypeSymbol(argName) as JSSimpleTypeInfo) + .ToArray(); + if (signatureTypes.Any(x => x is null)) + { + return new JSInvalidTypeInfo(); } - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSFunctionTypeInfo(false, signatureTypes); default: - return new JSInvalidTypeInfo(fullTypeName, diagnosticFormattedName); + return new JSInvalidTypeInfo(); } } + } - public static TypePositionInfo CreateForType(TypePositionInfo inner, ITypeSymbol type, MarshallingInfo jsMarshallingInfo, Compilation compilation) - { - ManagedTypeInfo jsTypeInfo = CreateJSTypeInfoForTypeSymbol(type); - var typeInfo = new TypePositionInfo(jsTypeInfo, jsMarshallingInfo) - { - InstanceIdentifier = inner.InstanceIdentifier, - RefKind = inner.RefKind, - RefKindSyntax = inner.RefKindSyntax, - ByValueContentsMarshalKind = inner.ByValueContentsMarshalKind - }; + internal sealed record JSInvalidTypeInfo() : JSSimpleTypeInfo(KnownManagedType.None); - return typeInfo; + internal record JSSimpleTypeInfo(KnownManagedType KnownType) : JSTypeInfo(KnownType) + { + public JSSimpleTypeInfo(KnownManagedType knownType, TypeSyntax syntax) + : this(knownType) + { + Syntax = syntax; } + public TypeSyntax Syntax { get; init; } } - internal sealed record JSInvalidTypeInfo(string FullTypeName, string DiagnosticFormattedName) : JSSimpleTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.None); - - internal record JSSimpleTypeInfo(string FullTypeName, string DiagnosticFormattedName, KnownManagedType KnownType) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownType); - - internal sealed record JSArrayTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Array); + internal sealed record JSArrayTypeInfo(JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(KnownManagedType.Array); - internal sealed record JSSpanTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Span); + internal sealed record JSSpanTypeInfo(JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(KnownManagedType.Span); - internal sealed record JSArraySegmentTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.ArraySegment); + internal sealed record JSArraySegmentTypeInfo(JSSimpleTypeInfo ElementTypeInfo) : JSTypeInfo(KnownManagedType.ArraySegment); - internal sealed record JSTaskTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Task); + internal sealed record JSTaskTypeInfo(JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(KnownManagedType.Task); - internal sealed record JSNullableTypeInfo(string FullTypeName, string DiagnosticFormattedName, JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, KnownManagedType.Nullable); + internal sealed record JSNullableTypeInfo(JSSimpleTypeInfo ResultTypeInfo) : JSTypeInfo(KnownManagedType.Nullable); - internal sealed record JSFunctionTypeInfo(string FullTypeName, string DiagnosticFormattedName, bool IsAction, JSSimpleTypeInfo[] ArgsTypeInfo) : JSTypeInfo(FullTypeName, DiagnosticFormattedName, (IsAction ? KnownManagedType.Action : KnownManagedType.Function)); + internal sealed record JSFunctionTypeInfo(bool IsAction, JSSimpleTypeInfo[] ArgsTypeInfo) : JSTypeInfo(IsAction ? KnownManagedType.Action : KnownManagedType.Function); } diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs new file mode 100644 index 0000000000000000000000000000000000000000..179cc065f03ed498d2198d466c85ad2ad79a8a23 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallAsAttributeInfoParser.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Runtime.InteropServices.JavaScript; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop.JavaScript +{ + internal sealed class JSMarshalAsAttributeParser : IMarshallingInfoAttributeParser, IUseSiteAttributeParser + { + private readonly INamedTypeSymbol _jsMarshalAsAttribute; + + public JSMarshalAsAttributeParser(Compilation compilation) + { + _jsMarshalAsAttribute = compilation.GetTypeByMetadataName(Constants.JSMarshalAsAttribute)!.ConstructUnboundGenericType(); + } + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.IsGenericType && SymbolEqualityComparer.Default.Equals(_jsMarshalAsAttribute, attributeType.ConstructUnboundGenericType()); + public MarshallingInfo ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + JSTypeFlags jsType = JSTypeFlags.None; + List jsTypeArguments = new List(); + INamedTypeSymbol? jsTypeArgs = attributeData.AttributeClass.TypeArguments[0] as INamedTypeSymbol; + if (jsTypeArgs.IsGenericType) + { + string gt = jsTypeArgs.ConstructUnboundGenericType().ToDisplayString(); + string name = gt.Substring(gt.IndexOf("JSType") + "JSType.".Length); + name = name.Substring(0, name.IndexOf("<")); + + Enum.TryParse(name, out jsType); + + foreach (var ta in jsTypeArgs.TypeArguments.Cast().Select(x => x.ToDisplayString())) + { + string argName = ta.Substring(ta.IndexOf("JSType") + "JSType.".Length); + JSTypeFlags jsTypeArg = JSTypeFlags.None; + Enum.TryParse(argName, out jsTypeArg); + jsTypeArguments.Add(jsTypeArg); + } + } + else + { + string st = jsTypeArgs.ToDisplayString(); + string name = st.Substring(st.IndexOf("JSType") + "JSType.".Length); + Enum.TryParse(name, out jsType); + } + + if (jsType == JSTypeFlags.None) + { + return new JSMissingMarshallingInfo(JSTypeInfo.CreateJSTypeInfoForTypeSymbol(type)); + } + + return new JSMarshallingInfo(NoMarshallingInfo.Instance, JSTypeInfo.CreateJSTypeInfoForTypeSymbol(type)) + { + JSType = jsType, + JSTypeArguments = jsTypeArguments.ToArray(), + }; + } + + UseSiteAttributeData IUseSiteAttributeParser.ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + return new UseSiteAttributeData(0, NoCountInfo.Instance, attributeData); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs deleted file mode 100644 index b6fdaa576b5783487ab816bda0d3ad88cae40ae7..0000000000000000000000000000000000000000 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingAttributeInfoParser.cs +++ /dev/null @@ -1,86 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices.JavaScript; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Interop.JavaScript -{ - public sealed class JSMarshallingAttributeInfoParser - { - private readonly ITypeSymbol _jsMarshalAsAttribute; - private readonly ITypeSymbol _marshalUsingAttribute; - - public JSMarshallingAttributeInfoParser( - Compilation compilation, - IGeneratorDiagnostics diagnostics, - ISymbol contextSymbol) - { - _jsMarshalAsAttribute = compilation.GetTypeByMetadataName(Constants.JSMarshalAsAttribute)!.ConstructUnboundGenericType(); - _marshalUsingAttribute = compilation.GetTypeByMetadataName(Constants.MarshalUsingAttribute)!; - } - - public MarshallingInfo ParseMarshallingInfo( - ITypeSymbol managedType, - IEnumerable useSiteAttributes, - MarshallingInfo inner) - { - JSTypeFlags jsType = JSTypeFlags.None; - List jsTypeArguments = new List(); - - foreach (AttributeData useSiteAttribute in useSiteAttributes) - { - INamedTypeSymbol attributeClass = useSiteAttribute.AttributeClass!; - if (attributeClass.IsGenericType && SymbolEqualityComparer.Default.Equals(_jsMarshalAsAttribute, attributeClass.ConstructUnboundGenericType())) - { - INamedTypeSymbol? jsTypeArgs = attributeClass.TypeArguments[0] as INamedTypeSymbol; - if (jsTypeArgs.IsGenericType) - { - string gt = jsTypeArgs.ConstructUnboundGenericType().ToDisplayString(); - string name = gt.Substring(gt.IndexOf("JSType") + "JSType.".Length); - name = name.Substring(0, name.IndexOf("<")); - - Enum.TryParse(name, out jsType); - - foreach (var ta in jsTypeArgs.TypeArguments.Cast().Select(x => x.ToDisplayString())) - { - string argName = ta.Substring(ta.IndexOf("JSType") + "JSType.".Length); - JSTypeFlags jsTypeArg = JSTypeFlags.None; - Enum.TryParse(argName, out jsTypeArg); - jsTypeArguments.Add(jsTypeArg); - } - } - else - { - string st = jsTypeArgs.ToDisplayString(); - string name = st.Substring(st.IndexOf("JSType") + "JSType.".Length); - Enum.TryParse(name, out jsType); - } - - } - if (SymbolEqualityComparer.Default.Equals(_marshalUsingAttribute, attributeClass)) - { - return new JSMarshallingInfo(inner) - { - JSType = JSTypeFlags.Array, - JSTypeArguments = Array.Empty(), - }; - } - } - - if (jsType == JSTypeFlags.None) - { - return new JSMissingMarshallingInfo(); - } - - return new JSMarshallingInfo(inner) - { - JSType = jsType, - JSTypeArguments = jsTypeArguments.ToArray(), - }; - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs index d18f67b47928cd440243d6e65aeffd369cfab3e9..edeb7961aadb1eb40dc607ba8541f4e2a37a2600 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSMarshallingInfo.cs @@ -6,26 +6,24 @@ namespace Microsoft.Interop.JavaScript { - internal record JSMarshallingInfo : MarshallingInfo + internal record JSMarshallingInfo(MarshallingInfo Inner, JSTypeInfo TypeInfo) : MarshallingInfo { - public MarshallingInfo Inner; - public JSTypeFlags JSType; - public JSTypeFlags[] JSTypeArguments; - public JSMarshallingInfo(MarshallingInfo inner) - { - Inner = inner; - } protected JSMarshallingInfo() + :this(NoMarshallingInfo.Instance, new JSInvalidTypeInfo()) { Inner = null; } + + public JSTypeFlags JSType { get; init; } + public JSTypeFlags[] JSTypeArguments { get; init; } } internal sealed record JSMissingMarshallingInfo : JSMarshallingInfo { - public JSMissingMarshallingInfo() + public JSMissingMarshallingInfo(JSTypeInfo typeInfo) { JSType = JSTypeFlags.Missing; + TypeInfo = typeInfo; } } } diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs index fed51dde6d6f6c880cce6b687b9478ae4e87b4a2..ecbb1fcaa76c9e09b91501924cfb980a676ceb55 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/FuncJSGenerator.cs @@ -48,9 +48,9 @@ public override IEnumerable Generate(TypePositionInfo info, Stu ? Argument(IdentifierName(context.GetIdentifiers(info).native)) : _inner.AsArgument(info, context); - var jsty = (JSFunctionTypeInfo)info.ManagedType; + var jsty = (JSFunctionTypeInfo)((JSMarshallingInfo)info.MarshallingAttributeInfo).TypeInfo; var sourceTypes = jsty.ArgsTypeInfo - .Select(a => ParseTypeName(a.FullTypeName)) + .Select(a => a.Syntax) .ToArray(); if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && context.Direction == CustomTypeMarshallingDirection.In && info.IsManagedReturnPosition) diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs index 86162f0e62c945e19578a8206d4398557489372b..df9cee6af62ba257fea4da8586a330c7688b9a6c 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Marshaling/TaskJSGenerator.cs @@ -20,8 +20,8 @@ public TaskJSGenerator(MarshalerType resultMarshalerType) public override IEnumerable GenerateBind(TypePositionInfo info, StubCodeContext context) { - var jsty = (JSTaskTypeInfo)info.ManagedType; - if (jsty.ResultTypeInfo.FullTypeName == "void") + var jsty = (JSTaskTypeInfo)((JSMarshallingInfo)info.MarshallingAttributeInfo).TypeInfo; + if (jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void)) { yield return InvocationExpression(MarshalerTypeName(MarshalerType.Task), ArgumentList()); } @@ -34,7 +34,7 @@ public override IEnumerable GenerateBind(TypePositionInfo info public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { - var jsty = (JSTaskTypeInfo)info.ManagedType; + var jsty = (JSTaskTypeInfo)((JSMarshallingInfo)info.MarshallingAttributeInfo).TypeInfo; string argName = context.GetAdditionalIdentifier(info, "js_arg"); var target = info.IsManagedReturnPosition @@ -47,14 +47,14 @@ public override IEnumerable Generate(TypePositionInfo info, Stu if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && context.Direction == CustomTypeMarshallingDirection.In && info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToManagedMethodVoid(target, source) : ToManagedMethod(target, source, jsty.ResultTypeInfo.Syntax); } if (context.CurrentStage == StubCodeContext.Stage.Marshal && context.Direction == CustomTypeMarshallingDirection.Out && info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToJSMethodVoid(target, source) : ToJSMethod(target, source, jsty.ResultTypeInfo.Syntax); } @@ -66,14 +66,14 @@ public override IEnumerable Generate(TypePositionInfo info, Stu if (context.CurrentStage == StubCodeContext.Stage.Invoke && context.Direction == CustomTypeMarshallingDirection.In && !info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToJSMethodVoid(target, source) : ToJSMethod(target, source, jsty.ResultTypeInfo.Syntax); } if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && context.Direction == CustomTypeMarshallingDirection.Out && !info.IsManagedReturnPosition) { - yield return jsty.ResultTypeInfo.FullTypeName == "void" + yield return jsty.ResultTypeInfo is JSSimpleTypeInfo(KnownManagedType.Void) ? ToManagedMethodVoid(target, source) : ToManagedMethod(target, source, jsty.ResultTypeInfo.Syntax); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs index 204709cf60c86cca1e7c2016b43e857584c3035e..be23c9e3e555c5cae54a2245f25358172c36f6a0 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportAnalyzer.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; @@ -86,7 +87,8 @@ private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbo // later user work. AnyDiagnosticsSink diagnostics = new(); StubEnvironment env = context.Compilation.CreateStubEnvironment(); - SignatureContext targetSignatureContext = SignatureContext.Create(method, CreateInteropAttributeDataFromDllImport(dllImportData), env, diagnostics, typeof(ConvertToLibraryImportAnalyzer).Assembly); + AttributeData dllImportAttribute = method.GetAttributes().First(attr => attr.AttributeClass.ToDisplayString() == TypeNames.DllImportAttribute); + SignatureContext targetSignatureContext = SignatureContext.Create(method, CreateInteropAttributeDataFromDllImport(dllImportData), env, diagnostics, dllImportAttribute, typeof(ConvertToLibraryImportAnalyzer).Assembly); var generatorFactoryKey = LibraryImportGeneratorHelpers.CreateGeneratorFactory(env, new LibraryImportGeneratorOptions(context.Options.AnalyzerConfigOptionsProvider.GlobalOptions)); diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs index 1da4dbd49afeca4bdfdae652fdc473aa8914df5c..f74db2f8d845ea9a3287fcd2abe37fb43f567b44 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs @@ -312,7 +312,7 @@ private static SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenLis } // Create the stub. - var signatureContext = SignatureContext.Create(symbol, libraryImportData, environment, generatorDiagnostics, typeof(LibraryImportGenerator).Assembly); + var signatureContext = SignatureContext.Create(symbol, libraryImportData, environment, generatorDiagnostics, generatedDllImportAttr, typeof(LibraryImportGenerator).Assembly); var containingTypeContext = new ContainingSyntaxContext(originalSyntax); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..2c58acff79b16814cbf0c8ebf9b8c9c770efd568 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ArrayMarshallingInfoProvider.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for single-dimensional zero-based array types using the System.Runtime.InteropServices.Marshalling.ArrayMarshaller and System.Runtime.InteropServices.Marshalling.PointerArrayMarshaller + /// built-in types. + /// + public sealed class ArrayMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + + public ArrayMarshallingInfoProvider(Compilation compilation) + { + _compilation = compilation; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type is IArrayTypeSymbol { IsSZArray: true }; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + + ITypeSymbol elementType = ((IArrayTypeSymbol)type).ElementType; + return CreateArrayMarshallingInfo(_compilation, type, elementType, countInfo, marshallingInfoCallback(elementType, useSiteAttributes, indirectionDepth + 1)); + } + + public static MarshallingInfo CreateArrayMarshallingInfo( + Compilation compilation, + ITypeSymbol managedType, + ITypeSymbol elementType, + CountInfo countInfo, + MarshallingInfo elementMarshallingInfo) + { + ITypeSymbol typeArgumentToInsert = elementType; + INamedTypeSymbol? arrayMarshaller; + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_PointerArrayMarshaller_Metadata); + typeArgumentToInsert = pointedAt; + } + else + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_ArrayMarshaller_Metadata); + } + + if (arrayMarshaller is null) + { + // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. + return new MissingSupportCollectionMarshallingInfo(countInfo, elementMarshallingInfo); + } + + if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(arrayMarshaller) + && ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(arrayMarshaller)) + { + arrayMarshaller = arrayMarshaller.Construct( + typeArgumentToInsert, + arrayMarshaller.TypeArguments.Last()); + + Func getMarshallingInfoForElement = (ITypeSymbol elementType) => elementMarshallingInfo; + if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(arrayMarshaller, managedType, compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? marshallers)) + { + return new NativeLinearCollectionMarshallingInfo( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), + marshallers.Value, + countInfo, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller.TypeParameters.Last())); + } + } + + Debug.WriteLine("Default marshallers for arrays should be a valid shape."); + return NoMarshallingInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..cc3a1996957f19614ab96507de41f8693fef8c85 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BlittableTypeMarshallingInfoProvider.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for unmanaged types that may be blittable. + /// + public sealed class BlittableTypeMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + + public BlittableTypeMarshallingInfoProvider(Compilation compilation) + { + _compilation = compilation; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type is INamedTypeSymbol { IsUnmanagedType: true } unmanagedType + && unmanagedType.IsConsideredBlittable(); + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + if (type.TypeKind is TypeKind.Enum or TypeKind.Pointer or TypeKind.FunctionPointer + || type.SpecialType.IsAlwaysBlittable()) + { + // Treat primitive types and enums as having no marshalling info. + // They are supported in configurations where runtime marshalling is enabled. + return NoMarshallingInfo.Instance; + } + else if (_compilation.GetTypeByMetadataName(TypeNames.System_Runtime_CompilerServices_DisableRuntimeMarshallingAttribute) is null) + { + // If runtime marshalling cannot be disabled, then treat this as a "missing support" scenario so we can gracefully fall back to using the forwarder downlevel. + return new MissingSupportMarshallingInfo(); + } + else + { + return new UnmanagedBlittableMarshallingInfo(type.IsStrictlyBlittable()); + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..6a7b78acf1dd56db4f6847c5b156c727b5dff47f --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/BooleanMarshallingInfoProvider.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for bool elements without any marshalling information. + /// + public sealed class BooleanMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type.SpecialType == SpecialType.System_Boolean; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + // We intentionally don't support marshalling bool with no marshalling info + // as treating bool as a non-normalized 1-byte value is generally not a good default. + // Additionally, that default is different than the runtime marshalling, so by explicitly + // blocking bool marshalling without additional info, we make it a little easier + // to transition by explicitly notifying people of changing behavior. + return NoMarshallingInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..b8b6f4c4c1608a056965c4cb058cb49619992025 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CharMarshallingInfoProvider.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Marshalling information provider for char elements without any marshalling information on the element itself. + /// + public sealed class CharMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly DefaultMarshallingInfo _defaultMarshallingInfo; + + public CharMarshallingInfoProvider(DefaultMarshallingInfo defaultMarshallingInfo) + { + _defaultMarshallingInfo = defaultMarshallingInfo; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type.SpecialType == SpecialType.System_Char; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + // No marshalling info was computed, but a character encoding was provided. + // If the type is a character then pass on these details. + return _defaultMarshallingInfo.CharEncoding == CharEncoding.Undefined ? new UnmanagedBlittableMarshallingInfo(IsStrictlyBlittable: false) : new MarshallingInfoStringSupport(_defaultMarshallingInfo.CharEncoding); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs new file mode 100644 index 0000000000000000000000000000000000000000..d1df5d4e19ed822cd33d6e89be1fb378d85bacb4 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/CustomMarshallingInfoHelper.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + internal static class CustomMarshallingInfoHelper + { + public static MarshallingInfo CreateNativeMarshallingInfo( + ITypeSymbol type, + INamedTypeSymbol entryPointType, + AttributeData attrData, + UseSiteAttributeProvider useSiteAttributeProvider, + GetMarshallingInfoCallback getMarshallingInfoCallback, + int indirectionDepth, + CountInfo parsedCountInfo, + IGeneratorDiagnostics diagnostics, + Compilation compilation) + { + if (!ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(entryPointType)) + { + return NoMarshallingInfo.Instance; + } + + if (!(entryPointType.IsStatic && entryPointType.TypeKind == TypeKind.Class) + && entryPointType.TypeKind != TypeKind.Struct) + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerTypeMustBeStaticClassOrStruct), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + ManagedTypeInfo entryPointTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType); + + bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); + if (isLinearCollectionMarshalling) + { + // Update the entry point type with the type arguments based on the managed type + if (type is IArrayTypeSymbol arrayManagedType) + { + // Generally, we require linear collection marshallers to have "arity of managed type + 1" arity. + // However, arrays aren't "generic" over their element type as they're generics, but we want to treat the element type + // as a generic type parameter. As a result, we require an arity of 2 for array marshallers, 1 for the array element type, + // and 1 for the native element type (the required additional type parameter for linear collection marshallers). + if (entryPointType.Arity != 2) + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + entryPointType = entryPointType.ConstructedFrom.Construct( + arrayManagedType.ElementType, + entryPointType.TypeArguments.Last()); + } + else if (type is INamedTypeSymbol namedManagedCollectionType && entryPointType.IsUnboundGenericType) + { + if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( + namedManagedCollectionType, + entryPointType, + isLinearCollectionMarshalling, + (type, entryPointType) => diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), + out ITypeSymbol resolvedEntryPointType)) + { + return NoMarshallingInfo.Instance; + } + + entryPointType = (INamedTypeSymbol)resolvedEntryPointType; + } + else + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + Func getMarshallingInfoForElement = (ITypeSymbol elementType) => getMarshallingInfoCallback(elementType, useSiteAttributeProvider, indirectionDepth + 1); + if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers)) + { + return new NativeLinearCollectionMarshallingInfo( + entryPointTypeInfo, + collectionMarshallers.Value, + parsedCountInfo, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType.TypeParameters.Last())); + } + return NoMarshallingInfo.Instance; + } + + if (type is INamedTypeSymbol namedManagedType && entryPointType.IsUnboundGenericType) + { + if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( + namedManagedType, + entryPointType, + isLinearCollectionMarshalling, + (type, entryPointType) => diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), + out ITypeSymbol resolvedEntryPointType)) + { + return NoMarshallingInfo.Instance; + } + + entryPointType = (INamedTypeSymbol)resolvedEntryPointType; + } + + if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(entryPointType, type, compilation, out CustomTypeMarshallers? marshallers)) + { + return new NativeMarshallingAttributeInfo(entryPointTypeInfo, marshallers.Value); + } + return NoMarshallingInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs new file mode 100644 index 0000000000000000000000000000000000000000..4294f11610f20513cfc449d6ed2ef02890e6e9bb --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Simple User-application of System.Runtime.InteropServices.MarshalAsAttribute + /// + public sealed record MarshalAsInfo( + UnmanagedType UnmanagedType, + CharEncoding CharEncoding) : MarshallingInfoStringSupport(CharEncoding) + { + // UnmanagedType.LPUTF8Str is not in netstandard2.0, so we define a constant for the value here. + // See https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.unmanagedtype + internal const UnmanagedType UnmanagedType_LPUTF8Str = (UnmanagedType)0x30; + } + + /// + /// This class suppports parsing a System.Runtime.InteropServices.MarshalAsAttribute. + /// + public sealed class MarshalAsAttributeParser : IMarshallingInfoAttributeParser, IUseSiteAttributeParser + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + private readonly DefaultMarshallingInfo _defaultInfo; + + public MarshalAsAttributeParser(Compilation compilation, IGeneratorDiagnostics diagnostics, DefaultMarshallingInfo defaultInfo) + { + _compilation = compilation; + _diagnostics = diagnostics; + _defaultInfo = defaultInfo; + } + + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.System_Runtime_InteropServices_MarshalAsAttribute; + + UseSiteAttributeData IUseSiteAttributeParser.ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + ImmutableDictionary namedArguments = ImmutableDictionary.CreateRange(attributeData.NamedArguments); + SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; + if (namedArguments.TryGetValue(nameof(MarshalAsAttribute.SizeConst), out TypedConstant sizeConstArg)) + { + arraySizeInfo = arraySizeInfo with { ConstSize = (int)sizeConstArg.Value! }; + } + if (namedArguments.TryGetValue(nameof(MarshalAsAttribute.SizeParamIndex), out TypedConstant sizeParamIndexArg)) + { + if (!elementInfoProvider.TryGetInfoForParamIndex(attributeData, (short)sizeParamIndexArg.Value!, marshallingInfoCallback, out TypePositionInfo paramIndexInfo)) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, nameof(MarshalAsAttribute.SizeParamIndex), sizeParamIndexArg.Value.ToString()); + } + arraySizeInfo = arraySizeInfo with { ParamAtIndex = paramIndexInfo }; + } + return new UseSiteAttributeData(0, arraySizeInfo, attributeData); + } + + MarshallingInfo? IMarshallingInfoAttributeParser.ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + object unmanagedTypeObj = attributeData.ConstructorArguments[0].Value!; + UnmanagedType unmanagedType = unmanagedTypeObj is short unmanagedTypeAsShort + ? (UnmanagedType)unmanagedTypeAsShort + : (UnmanagedType)unmanagedTypeObj; + if (!Enum.IsDefined(typeof(UnmanagedType), unmanagedType) + || unmanagedType == UnmanagedType.CustomMarshaler + || unmanagedType == UnmanagedType.SafeArray) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, nameof(UnmanagedType), unmanagedType.ToString()); + } + + bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; + UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize; + + // All other data on attribute is defined as NamedArguments. + foreach (KeyValuePair namedArg in attributeData.NamedArguments) + { + switch (namedArg.Key) + { + case nameof(MarshalAsAttribute.SafeArraySubType): + case nameof(MarshalAsAttribute.SafeArrayUserDefinedSubType): + case nameof(MarshalAsAttribute.IidParameterIndex): + case nameof(MarshalAsAttribute.MarshalTypeRef): + case nameof(MarshalAsAttribute.MarshalType): + case nameof(MarshalAsAttribute.MarshalCookie): + _diagnostics.ReportConfigurationNotSupported(attributeData, $"{attributeData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + break; + case nameof(MarshalAsAttribute.ArraySubType): + if (!isArrayType) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, $"{attributeData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + } + elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; + break; + } + } + + if (isArrayType) + { + if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, nameof(UnmanagedType), unmanagedType.ToString()); + return NoMarshallingInfo.Instance; + } + + MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; + if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize) + { + elementMarshallingInfo = elementType.SpecialType == SpecialType.System_String + ? CreateStringMarshallingInfo(elementType, elementUnmanagedType) + : new MarshalAsInfo(elementUnmanagedType, _defaultInfo.CharEncoding); + } + else + { + elementMarshallingInfo = marshallingInfoCallback(elementType, useSiteAttributes, indirectionDepth + 1); + } + + CountInfo countInfo = NoCountInfo.Instance; + + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteAttributeData)) + { + countInfo = useSiteAttributeData.CountInfo; + } + + return ArrayMarshallingInfoProvider.CreateArrayMarshallingInfo(_compilation, type, elementType, countInfo, elementMarshallingInfo); + } + + if (type.SpecialType == SpecialType.System_String) + { + return CreateStringMarshallingInfo(type, unmanagedType); + } + + return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); + } + + private MarshallingInfo CreateStringMarshallingInfo( + ITypeSymbol type, + UnmanagedType unmanagedType) + { + string? marshallerName = unmanagedType switch + { + UnmanagedType.BStr => TypeNames.BStrStringMarshaller, + UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller, + UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller, + MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller, + _ => null + }; + + if (marshallerName is null) + return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); + + return StringMarshallingInfoProvider.CreateStringMarshallingInfo(_compilation, type, marshallerName); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs new file mode 100644 index 0000000000000000000000000000000000000000..14d60d3c9ce90f7c6f43f5ec20d6bb3f3d9dc24b --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalUsingAttributeParser.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// This class suppports parsing a System.Runtime.InteropServices.Marshalling.MarshalUsingAttribute. + /// + public sealed class MarshalUsingAttributeParser : IMarshallingInfoAttributeParser, IUseSiteAttributeParser + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + + public MarshalUsingAttributeParser(Compilation compilation, IGeneratorDiagnostics diagnostics) + { + _compilation = compilation; + _diagnostics = diagnostics; + } + + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.MarshalUsingAttribute; + + MarshallingInfo? IMarshallingInfoAttributeParser.ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + Debug.Assert(attributeData.AttributeClass!.ToDisplayString() == TypeNames.MarshalUsingAttribute); + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + + if (attributeData.ConstructorArguments.Length == 0) + { + // This attribute only has count information. + // It does not provide any marshalling info. + // Return null here to respresent the lack of any marshalling info, + // instead of the presence of invalid marshalling info. + return null; + } + + if (attributeData.ConstructorArguments[0].Value is not INamedTypeSymbol namedType) + { + return NoMarshallingInfo.Instance; + } + + return CustomMarshallingInfoHelper.CreateNativeMarshallingInfo( + type, + namedType, + attributeData, + useSiteAttributes, + marshallingInfoCallback, + indirectionDepth, + countInfo, + _diagnostics, + _compilation + ); + } + + UseSiteAttributeData IUseSiteAttributeParser.ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + ImmutableDictionary namedArgs = ImmutableDictionary.CreateRange(attributeData.NamedArguments); + CountInfo countInfo = ParseCountInfo(attributeData, namedArgs, elementInfoProvider, marshallingInfoCallback); + int elementIndirectionDepth = namedArgs.TryGetValue(ManualTypeMarshallingHelper.MarshalUsingProperties.ElementIndirectionDepth, out TypedConstant value) ? (int)value.Value! : 0; + return new UseSiteAttributeData(elementIndirectionDepth, countInfo, attributeData); + } + + private CountInfo ParseCountInfo(AttributeData attributeData, ImmutableDictionary namedArguments, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback) + { + int? constSize = null; + string? elementName = null; + foreach (KeyValuePair arg in attributeData.NamedArguments) + { + if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.ConstantElementCount) + { + constSize = (int)arg.Value.Value!; + } + else if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName) + { + if (arg.Value.Value is null) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, "null"); + return NoCountInfo.Instance; + } + elementName = (string)arg.Value.Value!; + } + } + + if (constSize is not null && elementName is not null) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attributeData, nameof(SR.ConstantAndElementCountInfoDisallowed)); + } + else if (constSize is not null) + { + return new ConstSizeCountInfo(constSize.Value); + } + else if (elementName is not null) + { + if (!elementInfoProvider.TryGetInfoForElementName(attributeData, elementName, marshallingInfoCallback, out TypePositionInfo elementInfo)) + { + _diagnostics.ReportConfigurationNotSupported(attributeData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, elementName); + return NoCountInfo.Instance; + } + return new CountElementCountInfo(elementInfo); + } + + return NoCountInfo.Instance; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index c245f8db61bcb19bc7c0b530ee86f1004f46d155..ab849630886ae1a97d2ed34e5412af56da6f2a4b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -281,13 +281,21 @@ public bool AnyIncomingEdge(int to) { if (nestedCollection.ElementCountInfo is CountElementCountInfo { ElementInfo: TypePositionInfo nestedCountElement }) { - yield return nestedCountElement; + // Do not include dependent elements with no managed or native index. + // These values are dummy values that are inserted earlier to avoid emitting extra diagnostics. + if (nestedCountElement.ManagedIndex != TypePositionInfo.UnsetIndex || nestedCountElement.NativeIndex != TypePositionInfo.UnsetIndex) + { + yield return nestedCountElement; + } } foreach (KeyValuePair mode in nestedCollection.Marshallers.Modes) { - foreach (TypePositionInfo nestedElements in GetDependentElementsOfMarshallingInfo(mode.Value.CollectionElementMarshallingInfo)) + foreach (TypePositionInfo nestedElement in GetDependentElementsOfMarshallingInfo(mode.Value.CollectionElementMarshallingInfo)) { - yield return nestedElements; + if (nestedElement.ManagedIndex != TypePositionInfo.UnsetIndex || nestedElement.NativeIndex != TypePositionInfo.UnsetIndex) + { + yield return nestedElement; + } } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index 19680588e7bad41a72d696fce245cee05bcfcf78..027267cf373a11e297f2e3baeba7aa0996f83764 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -15,7 +15,7 @@ namespace Microsoft.Interop /// Type used to pass on default marshalling details. /// /// - /// This type used to pass default marshalling details to . + /// This type used to pass default marshalling details to the various marshalling info parsers. /// Since it contains a , it should not be used as a field on any types /// derived from . See remarks on . /// @@ -79,18 +79,6 @@ public enum CharEncoding CharEncoding CharEncoding ) : MarshallingInfo; - /// - /// Simple User-application of System.Runtime.InteropServices.MarshalAsAttribute - /// - public sealed record MarshalAsInfo( - UnmanagedType UnmanagedType, - CharEncoding CharEncoding) : MarshallingInfoStringSupport(CharEncoding) - { - // UnmanagedType.LPUTF8Str is not in netstandard2.0, so we define a constant for the value here. - // See https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.unmanagedtype - internal const UnmanagedType UnmanagedType_LPUTF8Str = (UnmanagedType)0x30; - } - /// /// The provided type was determined to be an "unmanaged" type that can be passed as-is to native code. /// @@ -145,11 +133,6 @@ bool IsStrictlyBlittable EntryPointType, Marshallers); - /// - /// The type of the element is a SafeHandle-derived type with no marshalling attributes. - /// - public sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor, bool IsAbstract) : MarshallingInfo; - /// /// Marshalling information is lacking because of support not because it is /// unknown or non-existent. Includes information about element types in case @@ -160,720 +143,4 @@ bool IsStrictlyBlittable /// the forwarder marshaller. /// public sealed record MissingSupportCollectionMarshallingInfo(CountInfo CountInfo, MarshallingInfo ElementMarshallingInfo) : MissingSupportMarshallingInfo; - - public sealed class MarshallingAttributeInfoParser - { - private readonly Compilation _compilation; - private readonly IGeneratorDiagnostics _diagnostics; - private readonly DefaultMarshallingInfo _defaultInfo; - private readonly ISymbol _contextSymbol; - private readonly ITypeSymbol _marshalAsAttribute; - private readonly ITypeSymbol _marshalUsingAttribute; - - public MarshallingAttributeInfoParser( - Compilation compilation, - IGeneratorDiagnostics diagnostics, - DefaultMarshallingInfo defaultInfo, - ISymbol contextSymbol) - { - _compilation = compilation; - _diagnostics = diagnostics; - _defaultInfo = defaultInfo; - _contextSymbol = contextSymbol; - _marshalAsAttribute = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)!; - _marshalUsingAttribute = compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute)!; - } - - public MarshallingInfo ParseMarshallingInfo( - ITypeSymbol managedType, - IEnumerable useSiteAttributes) - { - return ParseMarshallingInfo(managedType, useSiteAttributes, ImmutableHashSet.Empty); - } - - private MarshallingInfo ParseMarshallingInfo( - ITypeSymbol managedType, - IEnumerable useSiteAttributes, - ImmutableHashSet inspectedElements) - { - Dictionary marshallingAttributesByIndirectionDepth = new(); - int maxIndirectionLevelDataProvided = 0; - foreach (AttributeData attribute in useSiteAttributes) - { - if (TryGetAttributeIndirectionLevel(attribute, out int indirectionLevel)) - { - if (marshallingAttributesByIndirectionDepth.ContainsKey(indirectionLevel)) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attribute, nameof(SR.DuplicateMarshallingInfo), indirectionLevel.ToString()); - return NoMarshallingInfo.Instance; - } - marshallingAttributesByIndirectionDepth.Add(indirectionLevel, attribute); - maxIndirectionLevelDataProvided = Math.Max(maxIndirectionLevelDataProvided, indirectionLevel); - } - } - - int maxIndirectionDepthUsed = 0; - MarshallingInfo info = GetMarshallingInfo( - managedType, - marshallingAttributesByIndirectionDepth, - indirectionLevel: 0, - inspectedElements, - ref maxIndirectionDepthUsed); - if (maxIndirectionDepthUsed < maxIndirectionLevelDataProvided) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo( - marshallingAttributesByIndirectionDepth[maxIndirectionLevelDataProvided], - nameof(SR.ExtraneousMarshallingInfo), - maxIndirectionLevelDataProvided.ToString(), - maxIndirectionDepthUsed.ToString()); - } - return info; - } - - private MarshallingInfo GetMarshallingInfo( - ITypeSymbol type, - Dictionary useSiteAttributes, - int indirectionLevel, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed) - { - maxIndirectionDepthUsed = Math.Max(indirectionLevel, maxIndirectionDepthUsed); - CountInfo parsedCountInfo = NoCountInfo.Instance; - - if (useSiteAttributes.TryGetValue(indirectionLevel, out AttributeData useSiteAttribute)) - { - INamedTypeSymbol attributeClass = useSiteAttribute.AttributeClass!; - - if (indirectionLevel == 0 - && SymbolEqualityComparer.Default.Equals(_compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) - { - // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateInfoFromMarshalAs(type, useSiteAttribute, inspectedElements, ref maxIndirectionDepthUsed); - } - else if (SymbolEqualityComparer.Default.Equals(_compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass)) - { - if (parsedCountInfo != NoCountInfo.Instance) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(useSiteAttribute, nameof(SR.DuplicateCountInfo)); - return NoMarshallingInfo.Instance; - } - parsedCountInfo = CreateCountInfo(useSiteAttribute, inspectedElements); - if (useSiteAttribute.ConstructorArguments.Length != 0) - { - return CreateNativeMarshallingInfo( - type, - (INamedTypeSymbol)useSiteAttribute.ConstructorArguments[0].Value!, - useSiteAttribute, - isMarshalUsingAttribute: true, - indirectionLevel, - parsedCountInfo, - useSiteAttributes, - inspectedElements, - ref maxIndirectionDepthUsed); - } - } - } - - // If we aren't overriding the marshalling at usage time, - // then fall back to the information on the element type itself. - foreach (AttributeData typeAttribute in type.GetAttributes()) - { - INamedTypeSymbol attributeClass = typeAttribute.AttributeClass!; - - if (attributeClass.ToDisplayString() == TypeNames.NativeMarshallingAttribute) - { - return CreateNativeMarshallingInfo( - type, - (INamedTypeSymbol)typeAttribute.ConstructorArguments[0].Value!, - typeAttribute, - isMarshalUsingAttribute: false, - indirectionLevel, - parsedCountInfo, - useSiteAttributes, - inspectedElements, - ref maxIndirectionDepthUsed); - } - } - - // If the type doesn't have custom attributes that dictate marshalling, - // then consider the type itself. - if (TryCreateTypeBasedMarshallingInfo( - type, - parsedCountInfo, - indirectionLevel, - useSiteAttributes, - inspectedElements, - ref maxIndirectionDepthUsed, - out MarshallingInfo infoMaybe)) - { - return infoMaybe; - } - - return NoMarshallingInfo.Instance; - } - - private CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashSet inspectedElements) - { - int? constSize = null; - string? elementName = null; - foreach (KeyValuePair arg in marshalUsingData.NamedArguments) - { - if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.ConstantElementCount) - { - constSize = (int)arg.Value.Value!; - } - else if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName) - { - if (arg.Value.Value is null) - { - _diagnostics.ReportConfigurationNotSupported(marshalUsingData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, "null"); - return NoCountInfo.Instance; - } - elementName = (string)arg.Value.Value!; - } - } - - if (constSize is not null && elementName is not null) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(marshalUsingData, nameof(SR.ConstantAndElementCountInfoDisallowed)); - } - else if (constSize is not null) - { - return new ConstSizeCountInfo(constSize.Value); - } - else if (elementName is not null) - { - if (inspectedElements.Contains(elementName)) - { - throw new CyclicalCountElementInfoException(inspectedElements, elementName); - } - - try - { - TypePositionInfo? elementInfo = CreateForElementName(elementName, inspectedElements.Add(elementName)); - if (elementInfo is null) - { - _diagnostics.ReportConfigurationNotSupported(marshalUsingData, ManualTypeMarshallingHelper.MarshalUsingProperties.CountElementName, elementName); - return NoCountInfo.Instance; - } - return new CountElementCountInfo(elementInfo); - } - // Specifically catch the exception when we're trying to inspect the element that started the cycle. - // This ensures that we've unwound the whole cycle so when we return NoCountInfo.Instance, there will be no cycles in the count info. - catch (CyclicalCountElementInfoException ex) when (ex.StartOfCycle == elementName) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(marshalUsingData, nameof(SR.CyclicalCountInfo), elementName); - return NoCountInfo.Instance; - } - } - - return NoCountInfo.Instance; - } - - private TypePositionInfo? CreateForParamIndex(AttributeData attrData, int paramIndex, ImmutableHashSet inspectedElements) - { - if (!(_contextSymbol is IMethodSymbol method && 0 <= paramIndex && paramIndex < method.Parameters.Length)) - { - return null; - } - IParameterSymbol param = method.Parameters[paramIndex]; - - if (inspectedElements.Contains(param.Name)) - { - throw new CyclicalCountElementInfoException(inspectedElements, param.Name); - } - - try - { - return TypePositionInfo.CreateForParameter( - param, - ParseMarshallingInfo(param.Type, param.GetAttributes(), inspectedElements.Add(param.Name)), _compilation) with - { ManagedIndex = paramIndex }; - } - // Specifically catch the exception when we're trying to inspect the element that started the cycle. - // This ensures that we've unwound the whole cycle so when we return, there will be no cycles in the count info. - catch (CyclicalCountElementInfoException ex) when (ex.StartOfCycle == param.Name) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CyclicalCountInfo), param.Name); - return SizeAndParamIndexInfo.UnspecifiedParam; - } - } - - private TypePositionInfo? CreateForElementName(string elementName, ImmutableHashSet inspectedElements) - { - if (_contextSymbol is IMethodSymbol method) - { - if (elementName == CountElementCountInfo.ReturnValueElementName) - { - return new TypePositionInfo( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), - ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes(), inspectedElements)) with - { - ManagedIndex = TypePositionInfo.ReturnIndex - }; - } - - for (int i = 0; i < method.Parameters.Length; i++) - { - IParameterSymbol param = method.Parameters[i]; - if (param.Name == elementName) - { - return TypePositionInfo.CreateForParameter(param, ParseMarshallingInfo(param.Type, param.GetAttributes(), inspectedElements), _compilation) with { ManagedIndex = i }; - } - } - } - else if (_contextSymbol is INamedTypeSymbol) - { - // TODO: Handle when we create a struct marshalling generator - // Do we want to support CountElementName pointing to only fields, or properties as well? - // If only fields, how do we handle properties with generated backing fields? - } - - return null; - } - - private MarshallingInfo CreateInfoFromMarshalAs( - ITypeSymbol type, - AttributeData attrData, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed) - { - object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; - UnmanagedType unmanagedType = unmanagedTypeObj is short unmanagedTypeAsShort - ? (UnmanagedType)unmanagedTypeAsShort - : (UnmanagedType)unmanagedTypeObj; - if (!Enum.IsDefined(typeof(UnmanagedType), unmanagedType) - || unmanagedType == UnmanagedType.CustomMarshaler - || unmanagedType == UnmanagedType.SafeArray) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - } - - bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; - UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize; - SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; - - // All other data on attribute is defined as NamedArguments. - foreach (KeyValuePair namedArg in attrData.NamedArguments) - { - switch (namedArg.Key) - { - default: - Debug.Fail($"An unknown member was found on {nameof(MarshalAsAttribute)}"); - continue; - case nameof(MarshalAsAttribute.SafeArraySubType): - case nameof(MarshalAsAttribute.SafeArrayUserDefinedSubType): - case nameof(MarshalAsAttribute.IidParameterIndex): - case nameof(MarshalAsAttribute.MarshalTypeRef): - case nameof(MarshalAsAttribute.MarshalType): - case nameof(MarshalAsAttribute.MarshalCookie): - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - break; - case nameof(MarshalAsAttribute.ArraySubType): - if (!isArrayType) - { - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; - break; - case nameof(MarshalAsAttribute.SizeConst): - if (!isArrayType) - { - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - arraySizeInfo = arraySizeInfo with { ConstSize = (int)namedArg.Value.Value! }; - break; - case nameof(MarshalAsAttribute.SizeParamIndex): - if (!isArrayType) - { - _diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - TypePositionInfo? paramIndexInfo = CreateForParamIndex(attrData, (short)namedArg.Value.Value!, inspectedElements); - - if (paramIndexInfo is null) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(MarshalAsAttribute.SizeParamIndex), namedArg.Value.Value.ToString()); - } - arraySizeInfo = arraySizeInfo with { ParamAtIndex = paramIndexInfo }; - break; - } - } - - if (isArrayType) - { - if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - _diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - return NoMarshallingInfo.Instance; - } - - MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; - if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedConstSize) - { - elementMarshallingInfo = elementType.SpecialType == SpecialType.System_String - ? CreateStringMarshallingInfo(elementType, elementUnmanagedType) - : new MarshalAsInfo(elementUnmanagedType, _defaultInfo.CharEncoding); - } - else - { - maxIndirectionDepthUsed = 1; - elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsed); - } - - return CreateArrayMarshallingInfo(type, elementType, arraySizeInfo, elementMarshallingInfo); - } - - if (type.SpecialType == SpecialType.System_String) - { - return CreateStringMarshallingInfo(type, unmanagedType); - } - - return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); - } - - private MarshallingInfo CreateNativeMarshallingInfo( - ITypeSymbol type, - INamedTypeSymbol entryPointType, - AttributeData attrData, - bool isMarshalUsingAttribute, - int indirectionLevel, - CountInfo parsedCountInfo, - Dictionary useSiteAttributes, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed) - { - if (!ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(entryPointType)) - { - return NoMarshallingInfo.Instance; - } - - if (!(entryPointType.IsStatic && entryPointType.TypeKind == TypeKind.Class) - && entryPointType.TypeKind != TypeKind.Struct) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerTypeMustBeStaticClassOrStruct), entryPointType.ToDisplayString(), type.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - ManagedTypeInfo entryPointTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType); - - bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); - if (isLinearCollectionMarshalling) - { - // Update the entry point type with the type arguments based on the managed type - if (type is IArrayTypeSymbol arrayManagedType) - { - if (entryPointType.Arity != 2) - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - entryPointType = entryPointType.ConstructedFrom.Construct( - arrayManagedType.ElementType, - entryPointType.TypeArguments.Last()); - } - else if (type is INamedTypeSymbol namedManagedCollectionType && entryPointType.IsUnboundGenericType) - { - if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( - namedManagedCollectionType, - entryPointType, - isLinearCollectionMarshalling, - (type, entryPointType) => _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), - out ITypeSymbol resolvedEntryPointType)) - { - return NoMarshallingInfo.Instance; - } - - entryPointType = (INamedTypeSymbol)resolvedEntryPointType; - } - else - { - _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed; - Func getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsedLocal); - if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers)) - { - maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal; - return new NativeLinearCollectionMarshallingInfo( - entryPointTypeInfo, - collectionMarshallers.Value, - parsedCountInfo, - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType.TypeParameters.Last())); - } - return NoMarshallingInfo.Instance; - } - - if (type is INamedTypeSymbol namedManagedType && entryPointType.IsUnboundGenericType) - { - if (!ManualTypeMarshallingHelper.TryResolveEntryPointType( - namedManagedType, - entryPointType, - isLinearCollectionMarshalling, - (type, entryPointType) => _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()), - out ITypeSymbol resolvedEntryPointType)) - { - return NoMarshallingInfo.Instance; - } - - entryPointType = (INamedTypeSymbol)resolvedEntryPointType; - } - - if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(entryPointType, type, _compilation, out CustomTypeMarshallers? marshallers)) - { - return new NativeMarshallingAttributeInfo(entryPointTypeInfo, marshallers.Value); - } - return NoMarshallingInfo.Instance; - } - - private bool TryCreateTypeBasedMarshallingInfo( - ITypeSymbol type, - CountInfo parsedCountInfo, - int indirectionLevel, - Dictionary useSiteAttributes, - ImmutableHashSet inspectedElements, - ref int maxIndirectionDepthUsed, - out MarshallingInfo marshallingInfo) - { - // Check for an implicit SafeHandle conversion. - // The SafeHandle type might not be defined if we're using one of the test CoreLib implementations used for NativeAOT. - ITypeSymbol? safeHandleType = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle); - if (safeHandleType is not null) - { - CodeAnalysis.Operations.CommonConversion conversion = _compilation.ClassifyCommonConversion(type, safeHandleType); - if (conversion.Exists - && conversion.IsImplicit - && (conversion.IsReference || conversion.IsIdentity)) - { - bool hasAccessibleDefaultConstructor = false; - if (type is INamedTypeSymbol named && !named.IsAbstract && named.InstanceConstructors.Length > 0) - { - foreach (IMethodSymbol ctor in named.InstanceConstructors) - { - if (ctor.Parameters.Length == 0) - { - hasAccessibleDefaultConstructor = _compilation.IsSymbolAccessibleWithin(ctor, _contextSymbol.ContainingType); - break; - } - } - } - marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor, type.IsAbstract); - return true; - } - } - - if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - MarshallingInfo elementMarshallingInfo = GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsed); - marshallingInfo = CreateArrayMarshallingInfo(type, elementType, parsedCountInfo, elementMarshallingInfo); - return true; - } - - // No marshalling info was computed, but a character encoding was provided. - // If the type is a character or string then pass on these details. - if (type.SpecialType == SpecialType.System_Char && _defaultInfo.CharEncoding != CharEncoding.Undefined) - { - marshallingInfo = new MarshallingInfoStringSupport(_defaultInfo.CharEncoding); - return true; - } - - if (type.SpecialType == SpecialType.System_String && _defaultInfo.CharEncoding != CharEncoding.Undefined) - { - if (_defaultInfo.CharEncoding == CharEncoding.Custom) - { - if (_defaultInfo.StringMarshallingCustomType is not null) - { - AttributeData attrData = _contextSymbol is IMethodSymbol - ? _contextSymbol.GetAttributes().FirstOrDefault(a => a.AttributeClass.ToDisplayString() == TypeNames.LibraryImportAttribute) - : default; - marshallingInfo = CreateNativeMarshallingInfo(type, _defaultInfo.StringMarshallingCustomType, attrData, true, indirectionLevel, parsedCountInfo, useSiteAttributes, inspectedElements, ref maxIndirectionDepthUsed); - return true; - } - } - else - { - marshallingInfo = _defaultInfo.CharEncoding switch - { - CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller), - CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller), - _ => throw new InvalidOperationException() - }; - - return true; - } - - marshallingInfo = new MarshallingInfoStringSupport(_defaultInfo.CharEncoding); - return true; - } - - - if (type.SpecialType == SpecialType.System_Boolean) - { - // We explicitly don't support marshalling bool without any marshalling info - // as treating bool as a non-normalized 1-byte value is generally not a good default. - // Additionally, that default is different than the runtime marshalling, so by explicitly - // blocking bool marshalling without additional info, we make it a little easier - // to transition by explicitly notifying people of changing behavior. - marshallingInfo = NoMarshallingInfo.Instance; - return false; - } - - if (type is INamedTypeSymbol { IsUnmanagedType: true } unmanagedType - && unmanagedType.IsConsideredBlittable()) - { - marshallingInfo = GetBlittableMarshallingInfo(type); - return true; - } - - marshallingInfo = NoMarshallingInfo.Instance; - return false; - } - - private MarshallingInfo CreateArrayMarshallingInfo( - ITypeSymbol managedType, - ITypeSymbol elementType, - CountInfo countInfo, - MarshallingInfo elementMarshallingInfo) - { - ITypeSymbol typeArgumentToInsert = elementType; - INamedTypeSymbol? arrayMarshaller; - if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_PointerArrayMarshaller_Metadata); - typeArgumentToInsert = pointedAt; - } - else - { - arrayMarshaller = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_ArrayMarshaller_Metadata); - } - - if (arrayMarshaller is null) - { - // If the array marshaler type is not available, then we cannot marshal arrays but indicate it is missing. - return new MissingSupportCollectionMarshallingInfo(countInfo, elementMarshallingInfo); - } - - if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(arrayMarshaller) - && ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(arrayMarshaller)) - { - arrayMarshaller = arrayMarshaller.Construct( - typeArgumentToInsert, - arrayMarshaller.TypeArguments.Last()); - - Func getMarshallingInfoForElement = (ITypeSymbol elementType) => elementMarshallingInfo; - if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(arrayMarshaller, managedType, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? marshallers)) - { - return new NativeLinearCollectionMarshallingInfo( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), - marshallers.Value, - countInfo, - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller.TypeParameters.Last())); - } - } - - Debug.WriteLine("Default marshallers for arrays should be a valid shape."); - return NoMarshallingInfo.Instance; - } - - private MarshallingInfo CreateStringMarshallingInfo( - ITypeSymbol type, - UnmanagedType unmanagedType) - { - string? marshallerName = unmanagedType switch - { - UnmanagedType.BStr => TypeNames.BStrStringMarshaller, - UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller, - UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller, - MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller, - _ => null - }; - - if (marshallerName is null) - return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); - - return CreateStringMarshallingInfo(type, marshallerName); - } - - private MarshallingInfo CreateStringMarshallingInfo( - ITypeSymbol type, - string marshallerName) - { - INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName); - if (stringMarshaller is null) - return new MissingSupportMarshallingInfo(); - - if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(stringMarshaller)) - { - if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(stringMarshaller, type, _compilation, out CustomTypeMarshallers? marshallers)) - { - return new NativeMarshallingAttributeInfo( - EntryPointType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(stringMarshaller), - Marshallers: marshallers.Value); - } - } - - return new MissingSupportMarshallingInfo(); - } - - private MarshallingInfo GetBlittableMarshallingInfo(ITypeSymbol type) - { - if (type.TypeKind is TypeKind.Enum or TypeKind.Pointer or TypeKind.FunctionPointer - || type.SpecialType.IsAlwaysBlittable()) - { - // Treat primitive types and enums as having no marshalling info. - // They are supported in configurations where runtime marshalling is enabled. - return NoMarshallingInfo.Instance; - } - else if (_compilation.GetTypeByMetadataName(TypeNames.System_Runtime_CompilerServices_DisableRuntimeMarshallingAttribute) is null) - { - // If runtime marshalling cannot be disabled, then treat this as a "missing support" scenario so we can gracefully fall back to using the forwarder downlevel. - return new MissingSupportMarshallingInfo(); - } - else - { - return new UnmanagedBlittableMarshallingInfo(type.IsStrictlyBlittable()); - } - } - - private bool TryGetAttributeIndirectionLevel(AttributeData attrData, out int indirectionLevel) - { - if (SymbolEqualityComparer.Default.Equals(attrData.AttributeClass, _marshalAsAttribute)) - { - indirectionLevel = 0; - return true; - } - - if (!SymbolEqualityComparer.Default.Equals(attrData.AttributeClass, _marshalUsingAttribute)) - { - indirectionLevel = 0; - return false; - } - - foreach (KeyValuePair arg in attrData.NamedArguments) - { - if (arg.Key == ManualTypeMarshallingHelper.MarshalUsingProperties.ElementIndirectionDepth) - { - indirectionLevel = (int)arg.Value.Value!; - return true; - } - } - indirectionLevel = 0; - return true; - } - - private sealed class CyclicalCountElementInfoException : Exception - { - public CyclicalCountElementInfoException(ImmutableHashSet elementsInCycle, string startOfCycle) - { - ElementsInCycle = elementsInCycle; - StartOfCycle = startOfCycle; - } - - public ImmutableHashSet ElementsInCycle { get; } - - public string StartOfCycle { get; } - } - } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs new file mode 100644 index 0000000000000000000000000000000000000000..410166893a514a096c5aa8012ef4bf2df2f57395 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs @@ -0,0 +1,367 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Information from a marshalling attribute that can only be provided if the attribute is at the usage site. + /// + /// The indirection depth that the info applies to. + /// Any collection count information provided. + /// The original attribute data. + public sealed record UseSiteAttributeData(int IndirectionDepth, CountInfo CountInfo, AttributeData AttributeData); + + /// + /// A callback to get the marshalling info for a given type at the provided indirection depth with the provided attributes at its usage site. + /// + /// The managed type to get marshalling info for + /// The attributes at the use site + /// The target indirection level + /// Marshalling info for provided information. + public delegate MarshallingInfo GetMarshallingInfoCallback(ITypeSymbol type, UseSiteAttributeProvider useSiteAttributes, int indirectionDepth); + + /// + /// A parser for an attribute used at the marshalling site, such as a parameter or return value attribute. + /// + public interface IUseSiteAttributeParser + { + /// + /// Whether or not the parser can parse an attribute of the provided type. + /// + /// The attribute type + /// true if the parser can parse an attribute of the provided type; otherwise false + bool CanParseAttributeType(INamedTypeSymbol attributeType); + + /// + /// Parse the use-site information out of the provided attribute. + /// + /// The attribute data to parse. + /// The provider for information about other elements. This is used to retrieve information about other parameters that might be referenced by any count information. + /// A callback to provide to the when retrieving additional information. + /// The information about the attribute at the use site. + UseSiteAttributeData ParseAttribute(AttributeData attributeData, IElementInfoProvider elementInfoProvider, GetMarshallingInfoCallback marshallingInfoCallback); + } + + /// + /// A parser for an attribute that provides information about which marshaller to use. + /// + public interface IMarshallingInfoAttributeParser + { + /// + /// Whether or not the parser can parse an attribute of the provided type. + /// + /// The attribute type + /// true if the parser can parse an attribute of the provided type; otherwise false + bool CanParseAttributeType(INamedTypeSymbol attributeType); + + /// + /// Parse the attribute into marshalling information + /// + /// The attribute to parse + /// The managed type + /// The current indirection depth + /// Attributes provided at the usage site, such as for count information + /// A callback to get marshalling info for nested elements, in the case of a collection of collections. + /// Marshalling information parsed from the attribute, or null if no information could be parsed from the attribute + MarshallingInfo? ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback); + } + + /// + /// A provider of marshalling info based only on the managed type any any previously parsed use-site attribute information + /// + public interface ITypeBasedMarshallingInfoProvider + { + /// + /// Whether or not the provider can provide marshalling info for the given managed type. + /// + /// The managed type + /// true if the provider can provide info for the provided type; otherwise false + bool CanProvideMarshallingInfoForType(ITypeSymbol type); + /// + /// Get marshalling info for the provided type at the given indirection level. + /// + /// The managed type + /// The current indirection depth + /// Attributes provided at the usage site, such as for count information + /// A callback to get marshalling info for nested elements, in the case of a collection of collections. + /// Marshalling information for the provided type + MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback); + } + + /// + /// A provider of TypePositionInfo instances for other elements in the current signature, such as parameters or return values. + /// + public interface IElementInfoProvider + { + /// + /// Get the name for a given index. + /// + /// The index + /// The name associated with the provided index, or if the index does not correspond to an element. + string FindNameForParamIndex(int paramIndex); + /// + /// Get a instance for the given element name. + /// + /// The attribute to report diagnostics on. + /// The element name to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The to pass to the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + bool TryGetInfoForElementName(AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info); + /// + /// Get a instance for the given element index. + /// + /// The attribute to report diagnostics on. + /// The element index to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The to pass to the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + bool TryGetInfoForParamIndex(AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info); + } + + /// + /// Convenience extension methods for . + /// + public static class ElementInfoProviderExtensions + { + /// + /// Get a instance for the given element name. + /// + /// The attribute to report diagnostics on. + /// The element name to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + public static bool TryGetInfoForElementName(this IElementInfoProvider provider, AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, out TypePositionInfo info) + { + return provider.TryGetInfoForElementName(attrData, elementName, marshallingInfoCallback, provider, out info); + } + + /// + /// Get a instance for the given element index. + /// + /// The attribute to report diagnostics on. + /// The element index to retrieve a instance for. + /// A callback to retrieve marshalling info to put into the . + /// The produced info. + /// true if a instance could be created for the type; otherwise false + public static bool TryGetInfoForParamIndex(this IElementInfoProvider provider, AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, out TypePositionInfo info) + { + return provider.TryGetInfoForParamIndex(attrData, paramIndex, marshallingInfoCallback, provider, out info); + } + } + + /// + /// Overall parser for all marshalling info. + /// + /// + /// This type combines the provided parsers to enable parsing marshalling information from a single call. + /// For a given managed type and use site attributes, it will parse marshalling information in the following order: + /// + /// Parse attributes provided at the usage site for use-site-only information. + /// Parse attributes provided at the usage site for marshalling information. + /// If no marshalling information has been found yet, parse attributes provided at the definition of the managed type. + /// If no marshalling information has been found yet, generate marshalling information for the managed type itself. + /// + /// + public sealed class MarshallingInfoParser + { + private readonly IGeneratorDiagnostics _diagnostics; + private readonly IElementInfoProvider _elementInfoProvider; + private readonly ImmutableArray _useSiteMarshallingAttributeParsers; + private readonly ImmutableArray _marshallingAttributeParsers; + private readonly ImmutableArray _typeBasedMarshallingInfoProviders; + + /// + /// Construct a new . + /// + /// The diagnostics sink to report all diagnostics to. + /// An to retrieve information about other elements than the current element when parsing. + /// Parsers for retrieving use-site-only information from attributes. + /// Parsers for retrieving marshalling information from attributes and the managed type. + /// Parsers for retrieving marshalling information from the managed type only. + public MarshallingInfoParser( + IGeneratorDiagnostics diagnostics, + IElementInfoProvider elementInfoProvider, + ImmutableArray useSiteMarshallingAttributeParsers, + ImmutableArray marshallingAttributeParsers, + ImmutableArray typeBasedMarshallingInfoProviders) + { + _diagnostics = diagnostics; + // Always support cycle detection. Otherwise we can get stack-overflows, which does not provide a good dev experience for any customer scenario. + _elementInfoProvider = new CycleDetectingElementInfoProvider(elementInfoProvider, diagnostics); + _useSiteMarshallingAttributeParsers = useSiteMarshallingAttributeParsers; + _marshallingAttributeParsers = marshallingAttributeParsers; + _typeBasedMarshallingInfoProviders = typeBasedMarshallingInfoProviders; + } + + /// + /// Parse the marshalling info for the provided managed type and attributes at the usage site. + /// + /// The managed type + /// All attributes specified at the usage site + /// The parsed marshalling information + public MarshallingInfo ParseMarshallingInfo( + ITypeSymbol managedType, + IEnumerable useSiteAttributes) + { + UseSiteAttributeProvider useSiteAttributeProvider = new UseSiteAttributeProvider(_useSiteMarshallingAttributeParsers, useSiteAttributes, _elementInfoProvider, _diagnostics, GetMarshallingInfo); + + MarshallingInfo info = GetMarshallingInfo( + managedType, + useSiteAttributeProvider, + indirectionDepth: 0); + + useSiteAttributeProvider.OnAttributeUsageFinished(); + return info; + } + + private MarshallingInfo GetMarshallingInfo( + ITypeSymbol type, + UseSiteAttributeProvider useSiteAttributes, + int indirectionDepth) + { + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out UseSiteAttributeData useSiteAttribute)) + { + if (GetMarshallingInfoForAttribute(useSiteAttribute.AttributeData, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo) + { + return marshallingInfo; + } + } + + // If we aren't overriding the marshalling at usage time, + // then fall back to the information on the element type itself. + foreach (AttributeData typeAttribute in type.GetAttributes()) + { + if (GetMarshallingInfoForAttribute(typeAttribute, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo) + { + return marshallingInfo; + } + } + + // If the type doesn't have custom attributes that dictate marshalling, + // then consider the type itself. + return GetMarshallingInfoForType(type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) ?? NoMarshallingInfo.Instance; + } + + private MarshallingInfo? GetMarshallingInfoForAttribute(AttributeData attribute, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + foreach (var parser in _marshallingAttributeParsers) + { + // Automatically ignore invalid attributes. + // The compiler will already error on them. + if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass)) + { + return parser.ParseAttribute(attribute, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + } + } + return null; + } + + private MarshallingInfo? GetMarshallingInfoForType(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + foreach (var parser in _typeBasedMarshallingInfoProviders) + { + if (parser.CanProvideMarshallingInfoForType(type)) + { + return parser.GetMarshallingInfo(type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + } + } + return null; + } + } + + /// + /// Wraps another with support to detect infinite cycles in marshalling info (i.e. count information that refers to other elements that refer to the original element). + /// + internal sealed class CycleDetectingElementInfoProvider : IElementInfoProvider + { + private ImmutableHashSet _activeInspectingElements = ImmutableHashSet.Empty; + private readonly IElementInfoProvider _innerProvider; + private readonly IGeneratorDiagnostics _diagnostics; + + public CycleDetectingElementInfoProvider(IElementInfoProvider innerProvider, IGeneratorDiagnostics diagnostics) + { + _innerProvider = innerProvider; + _diagnostics = diagnostics; + } + + public string FindNameForParamIndex(int paramIndex) => _innerProvider.FindNameForParamIndex(paramIndex); + public bool TryGetInfoForElementName(AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, [NotNullWhen(true)] out TypePositionInfo? info) + { + ImmutableHashSet inspectedElements = _activeInspectingElements; + if (inspectedElements.Contains(elementName)) + { + throw new CyclicalElementInfoException(inspectedElements, elementName); + } + try + { + _activeInspectingElements = inspectedElements.Add(elementName); + return _innerProvider.TryGetInfoForElementName(attrData, elementName, marshallingInfoCallback, rootProvider, out info); + } + // Specifically catch the exception when we're trying to inspect the element that started the cycle. + // This ensures that we've unwound the whole cycle so when we return, there will be no cycles in the count info. + catch (CyclicalElementInfoException ex) when (ex.StartOfCycle == elementName) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CyclicalCountInfo), elementName); + // Create a dummy value for the invalid marshalling. We're already in an error state, so try to not report extraneous diagnostics. + info = new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance); + return true; + } + finally + { + _activeInspectingElements = inspectedElements; + } + } + + public bool TryGetInfoForParamIndex(AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, [NotNullWhen(true)] out TypePositionInfo? info) + { + ImmutableHashSet inspectedElements = _activeInspectingElements; + string paramName = _innerProvider.FindNameForParamIndex(paramIndex); + if (paramName is not null && inspectedElements.Contains(paramName)) + { + throw new CyclicalElementInfoException(inspectedElements, paramName); + } + + try + { + _activeInspectingElements = inspectedElements.Add(paramName); + return _innerProvider.TryGetInfoForParamIndex(attrData, paramIndex, marshallingInfoCallback, rootProvider, out info); + } + // Specifically catch the exception when we're trying to inspect the element that started the cycle. + // This ensures that we've unwound the whole cycle so when we return, there will be no cycles in the count info. + catch (CyclicalElementInfoException ex) when (ex.StartOfCycle == paramName) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.CyclicalCountInfo), paramName); + // Create a dummy value for the invalid marshalling. We're already in an error state, so try to not report extraneous diagnostics. + info = new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance); + return true; + } + finally + { + _activeInspectingElements = inspectedElements; + } + } + + private sealed class CyclicalElementInfoException : Exception + { + public CyclicalElementInfoException(ImmutableHashSet elementsInCycle, string startOfCycle) + { + ElementsInCycle = elementsInCycle; + StartOfCycle = startOfCycle; + } + + public ImmutableHashSet ElementsInCycle { get; } + + public string StartOfCycle { get; } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..e00eef6d2c0aea64b424c6df2e9b3652c2706db1 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MethodSignatureElementInfoProvider.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + public sealed class MethodSignatureElementInfoProvider : IElementInfoProvider + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _generatorDiagnostics; + private readonly IMethodSymbol _method; + private readonly ImmutableArray _useSiteAttributeParsers; + + public MethodSignatureElementInfoProvider(Compilation compilation, IGeneratorDiagnostics generatorDiagnostics, IMethodSymbol method, ImmutableArray useSiteAttributeParsers) + { + _compilation = compilation; + _generatorDiagnostics = generatorDiagnostics; + _method = method; + _useSiteAttributeParsers = useSiteAttributeParsers; + } + + public string FindNameForParamIndex(int paramIndex) => paramIndex >= _method.Parameters.Length ? string.Empty : _method.Parameters[paramIndex].Name; + + public bool TryGetInfoForElementName(AttributeData attrData, string elementName, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info) + { + if (elementName == CountElementCountInfo.ReturnValueElementName) + { + info = new TypePositionInfo( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(_method.ReturnType), + marshallingInfoCallback(_method.ReturnType, new UseSiteAttributeProvider(_useSiteAttributeParsers, _method.GetReturnTypeAttributes(), rootProvider, _generatorDiagnostics, marshallingInfoCallback), 0)) with + { + ManagedIndex = TypePositionInfo.ReturnIndex + }; + return true; + } + + for (int i = 0; i < _method.Parameters.Length; i++) + { + IParameterSymbol param = _method.Parameters[i]; + if (param.Name == elementName) + { + info = TypePositionInfo.CreateForParameter( + param, + marshallingInfoCallback(param.Type, new UseSiteAttributeProvider(_useSiteAttributeParsers, param.GetAttributes(), rootProvider, _generatorDiagnostics, marshallingInfoCallback), 0), _compilation) with + { + ManagedIndex = i + }; + return true; + } + } + info = null; + return false; + } + + public bool TryGetInfoForParamIndex(AttributeData attrData, int paramIndex, GetMarshallingInfoCallback marshallingInfoCallback, IElementInfoProvider rootProvider, out TypePositionInfo info) + { + if (paramIndex >= _method.Parameters.Length) + { + info = null; + return false; + } + IParameterSymbol param = _method.Parameters[paramIndex]; + + info = TypePositionInfo.CreateForParameter( + param, + marshallingInfoCallback(param.Type, new UseSiteAttributeProvider(_useSiteAttributeParsers, param.GetAttributes(), rootProvider, _generatorDiagnostics, marshallingInfoCallback), 0), _compilation) with + { + ManagedIndex = paramIndex + }; + return true; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs new file mode 100644 index 0000000000000000000000000000000000000000..c51116029143e7f05196a7eb89e0d35f4018789a --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeMarshallingAttributeParser.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + public sealed class NativeMarshallingAttributeParser : IMarshallingInfoAttributeParser + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + + public NativeMarshallingAttributeParser(Compilation compilation, IGeneratorDiagnostics diagnostics) + { + _compilation = compilation; + _diagnostics = diagnostics; + } + + public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.NativeMarshallingAttribute; + + public MarshallingInfo? ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + Debug.Assert(attributeData.AttributeClass!.ToDisplayString() == TypeNames.NativeMarshallingAttribute); + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out var useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + + if (attributeData.ConstructorArguments[0].Value is not INamedTypeSymbol entryPointType) + { + return NoMarshallingInfo.Instance; + } + + return CustomMarshallingInfoHelper.CreateNativeMarshallingInfo( + type, + entryPointType, + attributeData, + useSiteAttributes, + marshallingInfoCallback, + indirectionDepth, + countInfo, + _diagnostics, + _compilation + ); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..0309793096e7c41adc2f624a189ebbc857de2ef1 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SafeHandleMarshallingInfoProvider.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// The type of the element is a SafeHandle-derived type with no marshalling attributes. + /// + public sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor, bool IsAbstract) : MarshallingInfo; + + /// + /// This class supports generating marshalling info for SafeHandle-derived types. + /// + public sealed class SafeHandleMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + private readonly ITypeSymbol _containingScope; + + public SafeHandleMarshallingInfoProvider(Compilation compilation, ITypeSymbol containingScope) + { + _compilation = compilation; + _containingScope = containingScope; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) + { + // Check for an implicit SafeHandle conversion. + // The SafeHandle type might not be defined if we're using one of the test CoreLib implementations used for NativeAOT. + ITypeSymbol? safeHandleType = _compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle); + if (safeHandleType is not null) + { + CodeAnalysis.Operations.CommonConversion conversion = _compilation.ClassifyCommonConversion(type, safeHandleType); + if (conversion.Exists + && conversion.IsImplicit + && (conversion.IsReference || conversion.IsIdentity)) + { + return true; + } + } + return false; + } + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + bool hasAccessibleDefaultConstructor = false; + if (type is INamedTypeSymbol named && !named.IsAbstract && named.InstanceConstructors.Length > 0) + { + foreach (IMethodSymbol ctor in named.InstanceConstructors) + { + if (ctor.Parameters.Length == 0) + { + hasAccessibleDefaultConstructor = _compilation.IsSymbolAccessibleWithin(ctor, _containingScope); + break; + } + } + } + return new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor, type.IsAbstract); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs index da94ec9a541d263661a5207b426232521c3dd4cc..b9df1831512e7c19ee1b43111e01c36021858833 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs @@ -56,9 +56,10 @@ public IEnumerable StubParameters InteropAttributeData interopAttributeData, StubEnvironment env, IGeneratorDiagnostics diagnostics, + AttributeData signatureWideMarshallingAttributeData, Assembly generatorInfoAssembly) { - ImmutableArray typeInfos = GenerateTypeInformation(method, interopAttributeData, diagnostics, env); + ImmutableArray typeInfos = GenerateTypeInformation(method, interopAttributeData, diagnostics, env, signatureWideMarshallingAttributeData); ImmutableArray.Builder additionalAttrs = ImmutableArray.CreateBuilder(); @@ -99,7 +100,12 @@ public IEnumerable StubParameters }; } - private static ImmutableArray GenerateTypeInformation(IMethodSymbol method, InteropAttributeData interopAttributeData, IGeneratorDiagnostics diagnostics, StubEnvironment env) + private static ImmutableArray GenerateTypeInformation( + IMethodSymbol method, + InteropAttributeData interopAttributeData, + IGeneratorDiagnostics diagnostics, + StubEnvironment env, + AttributeData signatureWideMarshallingAttributeData) { // Compute the current default string encoding value. CharEncoding defaultEncoding = CharEncoding.Undefined; @@ -120,14 +126,32 @@ private static ImmutableArray GenerateTypeInformation(IMethodS var defaultInfo = new DefaultMarshallingInfo(defaultEncoding, interopAttributeData.StringMarshallingCustomType); - var marshallingAttributeParser = new MarshallingAttributeInfoParser(env.Compilation, diagnostics, defaultInfo, method); + var useSiteAttributeParsers = ImmutableArray.Create( + new MarshalAsAttributeParser(env.Compilation, diagnostics, defaultInfo), + new MarshalUsingAttributeParser(env.Compilation, diagnostics)); + + var marshallingInfoParser = new MarshallingInfoParser( + diagnostics, + new MethodSignatureElementInfoProvider(env.Compilation, diagnostics, method, useSiteAttributeParsers), + useSiteAttributeParsers, + ImmutableArray.Create( + new MarshalAsAttributeParser(env.Compilation, diagnostics, defaultInfo), + new MarshalUsingAttributeParser(env.Compilation, diagnostics), + new NativeMarshallingAttributeParser(env.Compilation, diagnostics)), + ImmutableArray.Create( + new SafeHandleMarshallingInfoProvider(env.Compilation, method.ContainingType), + new ArrayMarshallingInfoProvider(env.Compilation), + new CharMarshallingInfoProvider(defaultInfo), + new StringMarshallingInfoProvider(env.Compilation, diagnostics, signatureWideMarshallingAttributeData, defaultInfo), + new BooleanMarshallingInfoProvider(), + new BlittableTypeMarshallingInfoProvider(env.Compilation))); // Determine parameter and return types ImmutableArray.Builder typeInfos = ImmutableArray.CreateBuilder(); for (int i = 0; i < method.Parameters.Length; i++) { IParameterSymbol param = method.Parameters[i]; - MarshallingInfo marshallingInfo = marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); + MarshallingInfo marshallingInfo = marshallingInfoParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation); typeInfo = typeInfo with { @@ -137,7 +161,7 @@ private static ImmutableArray GenerateTypeInformation(IMethodS typeInfos.Add(typeInfo); } - TypePositionInfo retTypeInfo = new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); + TypePositionInfo retTypeInfo = new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), marshallingInfoParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); retTypeInfo = retTypeInfo with { ManagedIndex = TypePositionInfo.ReturnIndex, diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..e8c9aee264d3dff73b5dd9307ae773a885189fd7 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + + /// + /// This class supports generating marshalling info for the type. + /// This includes support for the System.Runtime.InteropServices.StringMarshalling enum. + /// + public sealed class StringMarshallingInfoProvider : ITypeBasedMarshallingInfoProvider + { + private readonly Compilation _compilation; + private readonly IGeneratorDiagnostics _diagnostics; + private readonly AttributeData _stringMarshallingCustomAttribute; + private readonly DefaultMarshallingInfo _defaultMarshallingInfo; + + public StringMarshallingInfoProvider(Compilation compilation, IGeneratorDiagnostics diagnostics, AttributeData stringMarshallingCustomAttribute, DefaultMarshallingInfo defaultMarshallingInfo) + { + _compilation = compilation; + _diagnostics = diagnostics; + _stringMarshallingCustomAttribute = stringMarshallingCustomAttribute; + _defaultMarshallingInfo = defaultMarshallingInfo; + } + + public bool CanProvideMarshallingInfoForType(ITypeSymbol type) => type.SpecialType == SpecialType.System_String; + + public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + if (_defaultMarshallingInfo.CharEncoding == CharEncoding.Undefined) + { + return NoMarshallingInfo.Instance; + } + else if (_defaultMarshallingInfo.CharEncoding == CharEncoding.Custom) + { + if (_defaultMarshallingInfo.StringMarshallingCustomType is not null) + { + CountInfo countInfo = NoCountInfo.Instance; + if (useSiteAttributes.TryGetUseSiteAttributeInfo(indirectionDepth, out var useSiteInfo)) + { + countInfo = useSiteInfo.CountInfo; + } + return CustomMarshallingInfoHelper.CreateNativeMarshallingInfo( + type, + _defaultMarshallingInfo.StringMarshallingCustomType, + _stringMarshallingCustomAttribute, + useSiteAttributes, + marshallingInfoCallback, + indirectionDepth, + countInfo, + _diagnostics, + _compilation); + } + } + else + { + // No marshalling info was computed, but a character encoding was provided. + return _defaultMarshallingInfo.CharEncoding switch + { + CharEncoding.Utf16 => CreateStringMarshallingInfo(_compilation, type, TypeNames.Utf16StringMarshaller), + CharEncoding.Utf8 => CreateStringMarshallingInfo(_compilation, type, TypeNames.Utf8StringMarshaller), + _ => throw new InvalidOperationException() + }; + } + + return new MarshallingInfoStringSupport(_defaultMarshallingInfo.CharEncoding); + } + + public static MarshallingInfo CreateStringMarshallingInfo( + Compilation compilation, + ITypeSymbol type, + string marshallerName) + { + INamedTypeSymbol? stringMarshaller = compilation.GetTypeByMetadataName(marshallerName); + if (stringMarshaller is null) + return new MissingSupportMarshallingInfo(); + + if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(stringMarshaller)) + { + if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(stringMarshaller, type, compilation, out CustomTypeMarshallers? marshallers)) + { + return new NativeMarshallingAttributeInfo( + EntryPointType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(stringMarshaller), + Marshallers: marshallers.Value); + } + } + + return new MissingSupportMarshallingInfo(); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..fa19c1991e26278dc2872f14e3332f0338512d09 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UseSiteAttributeProvider.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + /// + /// Parses information for all use-site attributes and tracks their usage. + /// + public sealed class UseSiteAttributeProvider + { + private readonly ImmutableDictionary _useSiteAttributesByIndirectionDepth; + private readonly int _maxIndirectionLevelDataProvided; + private readonly IGeneratorDiagnostics _diagnostics; + private int _maxIndirectionLevelUsed; + + /// + /// Construct a new for a given usage site. + /// + /// The parsers for the attributes at the given usage site. + /// The attributes at the usage site. + /// The provider for additional element information, used by the attribute parsers. + /// Diagnostics sink for any invalid configurations. + /// A callback to get marshalling information for other elements. Used by . + internal UseSiteAttributeProvider( + ImmutableArray useSiteAttributeParsers, + IEnumerable useSiteAttributes, + IElementInfoProvider elementInfoProvider, + IGeneratorDiagnostics diagnostics, + GetMarshallingInfoCallback getMarshallingInfoCallback) + { + ImmutableDictionary.Builder useSiteAttributesByIndirectionDepth = ImmutableDictionary.CreateBuilder(); + _maxIndirectionLevelDataProvided = 0; + foreach (AttributeData attribute in useSiteAttributes) + { + UseSiteAttributeData? useSiteAttributeData = GetUseSiteInfoForAttribute(attribute); + if (useSiteAttributeData is not null) + { + int indirectionDepth = useSiteAttributeData.IndirectionDepth; + if (useSiteAttributesByIndirectionDepth.ContainsKey(indirectionDepth)) + { + diagnostics.ReportInvalidMarshallingAttributeInfo(attribute, nameof(SR.DuplicateMarshallingInfo), indirectionDepth.ToString()); + } + else + { + useSiteAttributesByIndirectionDepth.Add(indirectionDepth, useSiteAttributeData); + _maxIndirectionLevelDataProvided = Math.Max(_maxIndirectionLevelDataProvided, indirectionDepth); + } + } + } + _useSiteAttributesByIndirectionDepth = useSiteAttributesByIndirectionDepth.ToImmutable(); + _diagnostics = diagnostics; + + UseSiteAttributeData? GetUseSiteInfoForAttribute(AttributeData attribute) + { + foreach (var parser in useSiteAttributeParsers) + { + // Automatically ignore invalid attributes. + // The compiler will already error on them. + if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass)) + { + return parser.ParseAttribute(attribute, elementInfoProvider, getMarshallingInfoCallback); + } + } + return null; + } + } + + /// + /// Get the provided for a given , if it exists. + /// + /// The indirection depth to retrieve info for. + /// The use site information, if it exists. + /// true if an attribute was provided for the given indirection depth. + public bool TryGetUseSiteAttributeInfo(int indirectionDepth, out UseSiteAttributeData useSiteInfo) + { + _maxIndirectionLevelUsed = Math.Max(indirectionDepth, _maxIndirectionLevelUsed); + return _useSiteAttributesByIndirectionDepth.TryGetValue(indirectionDepth, out useSiteInfo); + } + + /// + /// Call when no more of the use-site attribute information will be used. + /// Records any information or diagnostics about unused marshalling information. + /// + internal void OnAttributeUsageFinished() + { + if (_maxIndirectionLevelUsed < _maxIndirectionLevelDataProvided) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo( + _useSiteAttributesByIndirectionDepth[_maxIndirectionLevelDataProvided].AttributeData, + nameof(SR.ExtraneousMarshallingInfo), + _maxIndirectionLevelDataProvided.ToString(), + _maxIndirectionLevelUsed.ToString()); + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs index 5a77bb12d7b46632f45b125755633696ec826b4e..4d10887e3443321dc97947a59cc46dfbceb7a46a 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs @@ -120,8 +120,8 @@ public static IEnumerable CodeSnippetsToCompile() // Generic collection marshaller has different arity than collection. yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.GenericCollectionMarshallingArityMismatch, 2, 0 }; - yield return new object[] { ID(), CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 2, 0 }; - yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshallingDuplicateElementIndirectionDepth, 2, 0 }; + yield return new object[] { ID(), CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 1, 0 }; + yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshallingDuplicateElementIndirectionDepth, 1, 0 }; yield return new object[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshallingUnusedElementIndirectionDepth, 1, 0 }; yield return new object[] { ID(), CodeSnippets.RecursiveCountElementNameOnReturnValue, 2, 0 }; yield return new object[] { ID(), CodeSnippets.RecursiveCountElementNameOnParameter, 2, 0 };