未验证 提交 058bed79 编写于 作者: M Marek Fišera 提交者: GitHub

[wasm] Unmanaged structs are considered blittable if module has...

[wasm] Unmanaged structs are considered blittable if module has `DisableRuntimeMarshallingAttribute` (#73310)

* [wasm] Unmanaged structs are considered blittable if module has DisableRuntimeMarshallingAttribute.
* Look for the DisableRuntimeMarshallingAttribute on the assembly where pinvoke method is declared.
- Cache DisableRuntimeMarshallingAttribute on assembly.
- Test without DisableRuntimeMarshallingAttribute.
* WasmBuildNative only without AOT.
* Unit test with struct in different assembly.

* Update, and add some additional tests
* [wasm] PInvokeTableGenerator: Avoid crash when processing unmanaged
callbacks with function pointers.
Co-authored-by: NAnkit Jain <radical@gmail.com>
上级 2b14842a
......@@ -14,6 +14,7 @@
internal sealed class PInvokeTableGenerator
{
private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
private readonly Dictionary<Assembly, bool> _assemblyDisableRuntimeMarshallingAttributeCache = new();
private TaskLoggingHelper Log { get; set; }
......@@ -45,7 +46,7 @@ public IEnumerable<string> Generate(string[] pinvokeModules, string[] assemblies
using (var w = File.CreateText(tmpFileName))
{
EmitPInvokeTable(w, modules, pinvokes);
EmitNativeToInterp(w, callbacks);
EmitNativeToInterp(w, ref callbacks);
}
if (Utils.CopyIfDifferent(tmpFileName, outputPath, useHash: false))
......@@ -68,6 +69,8 @@ private void CollectPInvokes(List<PInvoke> pinvokes, List<PInvokeCallback> callb
try
{
CollectPInvokesForMethod(method);
if (DoesMethodHaveCallbacks(method))
callbacks.Add(new PInvokeCallback(method));
}
catch (Exception ex) when (ex is not LogAsErrorException)
{
......@@ -94,21 +97,57 @@ void CollectPInvokesForMethod(MethodInfo method)
Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
signatures.Add(signature);
}
}
bool DoesMethodHaveCallbacks(MethodInfo method)
{
if (!MethodHasCallbackAttributes(method))
return false;
if (TryIsMethodGetParametersUnsupported(method, out string? reason))
{
Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
$"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'.");
return false;
}
if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly))
return true;
// No DisableRuntimeMarshalling attribute, so check if the params/ret-type are
// blittable
bool isVoid = method.ReturnType.FullName == "System.Void";
if (!isVoid && !IsBlittable(method.ReturnType))
Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
foreach (var p in method.GetParameters())
{
if (!IsBlittable(p.ParameterType))
Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
}
return true;
}
static bool MethodHasCallbackAttributes(MethodInfo method)
{
foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
{
try
{
if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
callbacks.Add(new PInvokeCallback(method));
{
return true;
}
}
catch
{
// Assembly not found, ignore
}
}
return false;
}
}
......@@ -302,8 +341,10 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
if (TryIsMethodGetParametersUnsupported(pinvoke.Method, out string? reason))
{
// Don't use method.ToString() or any of it's parameters, or return type
// because at least one of those are unsupported, and will throw
Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
$"Skipping pinvoke '{pinvoke.Method.DeclaringType!.FullName}::{pinvoke.Method}' because '{reason}'.");
$"Skipping pinvoke '{pinvoke.Method.DeclaringType!.FullName}::{pinvoke.Method.Name}' because '{reason}'.");
pinvoke.Skip = true;
return null;
......@@ -324,7 +365,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
return sb.ToString();
}
private static void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> callbacks)
private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
{
// Generate native->interp entry functions
// These are called by native code, so they need to obtain
......@@ -339,22 +380,7 @@ private static void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> cal
// Arguments to interp entry functions in the runtime
w.WriteLine("InterpFtnDesc wasm_native_to_interp_ftndescs[" + callbacks.Count + "];");
foreach (var cb in callbacks)
{
MethodInfo method = cb.Method;
bool isVoid = method.ReturnType.FullName == "System.Void";
if (!isVoid && !IsBlittable(method.ReturnType))
Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
foreach (var p in method.GetParameters())
{
if (!IsBlittable(p.ParameterType))
Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
}
}
var callbackNames = new HashSet<string>();
foreach (var cb in callbacks)
{
var sb = new StringBuilder();
......@@ -460,6 +486,18 @@ private static void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> cal
w.WriteLine("};");
}
private bool HasAssemblyDisableRuntimeMarshallingAttribute(Assembly assembly)
{
if (!_assemblyDisableRuntimeMarshallingAttributeCache.TryGetValue(assembly, out var value))
{
_assemblyDisableRuntimeMarshallingAttributeCache[assembly] = value = assembly
.GetCustomAttributesData()
.Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute");
}
return value;
}
private static bool IsBlittable(Type type)
{
if (type.IsPrimitive || type.IsByRef || type.IsPointer || type.IsEnum)
......
......@@ -221,7 +221,9 @@ public BuildTestBase(ITestOutputHelper output, SharedBuildPerTestClassFixture bu
// App arguments
if (envVars != null)
{
var setenv = string.Join(' ', envVars.Select(kvp => $"\"--setenv={kvp.Key}={kvp.Value}\"").ToArray());
var setenv = string.Join(' ', envVars
.Where(ev => ev.Key != "PATH")
.Select(kvp => $"\"--setenv={kvp.Key}={kvp.Value}\"").ToArray());
args.Append($" {setenv}");
}
......
......@@ -119,6 +119,195 @@ public static int Main()
Assert.Contains("Main running", output);
}
[Theory]
[BuildAndRun(host: RunHost.None)]
public void UnmanagedStructAndMethodIn_SameAssembly_WithoutDisableRuntimeMarshallingAttribute_NotConsideredBlittable
(BuildArgs buildArgs, string id)
{
(_, string output) = SingleProjectForDisabledRuntimeMarshallingTest(
withDisabledRuntimeMarshallingAttribute: false,
expectSuccess: false,
buildArgs,
id
);
Assert.Matches("error.*Parameter.*types.*pinvoke.*.*blittable", output);
}
[Theory]
[BuildAndRun(host: RunHost.Chrome)]
public void UnmanagedStructAndMethodIn_SameAssembly_WithDisableRuntimeMarshallingAttribute_ConsideredBlittable
(BuildArgs buildArgs, RunHost host, string id)
{
(buildArgs, _) = SingleProjectForDisabledRuntimeMarshallingTest(
withDisabledRuntimeMarshallingAttribute: true,
expectSuccess: true,
buildArgs,
id
);
string output = RunAndTestWasmApp(buildArgs, buildDir: _projectDir, expectedExitCode: 42, host: host, id: id);
Assert.Contains("Main running 5", output);
}
private (BuildArgs buildArgs ,string output) SingleProjectForDisabledRuntimeMarshallingTest(bool withDisabledRuntimeMarshallingAttribute, bool expectSuccess, BuildArgs buildArgs, string id)
{
string code =
"""
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
"""
+ (withDisabledRuntimeMarshallingAttribute ? "[assembly: DisableRuntimeMarshalling]" : "")
+ """
public class Test
{
public static int Main()
{
var x = new S { Value = 5 };
Console.WriteLine("Main running " + x.Value);
return 42;
}
public struct S { public int Value; }
[UnmanagedCallersOnly]
public static void M(S myStruct) { }
}
""";
buildArgs = ExpandBuildArgs(
buildArgs with { ProjectName = $"not_blittable_{buildArgs.Config}_{id}" },
extraProperties: buildArgs.AOT
? string.Empty
: "<WasmBuildNative>true</WasmBuildNative>"
);
(_, string output) = BuildProject(
buildArgs,
id: id,
new BuildProjectOptions(
InitProject: () =>
{
File.WriteAllText(Path.Combine(_projectDir!, "Program.cs"), code);
},
Publish: buildArgs.AOT,
DotnetWasmFromRuntimePack: false,
ExpectSuccess: expectSuccess
)
);
return (buildArgs, output);
}
public static IEnumerable<object?[]> SeparateAssemblyWithDisableMarshallingAttributeTestData(string config)
=> ConfigWithAOTData(aot: false, config: config).Multiply(
new object[] { /*libraryHasAttribute*/ false, /*appHasAttribute*/ false, /*expectSuccess*/ false },
new object[] { /*libraryHasAttribute*/ true, /*appHasAttribute*/ false, /*expectSuccess*/ false },
new object[] { /*libraryHasAttribute*/ false, /*appHasAttribute*/ true, /*expectSuccess*/ true },
new object[] { /*libraryHasAttribute*/ true, /*appHasAttribute*/ true, /*expectSuccess*/ true }
).WithRunHosts(RunHost.Chrome).UnwrapItemsAsArrays();
[Theory]
[MemberData(nameof(SeparateAssemblyWithDisableMarshallingAttributeTestData), parameters: "Debug")]
[MemberData(nameof(SeparateAssemblyWithDisableMarshallingAttributeTestData), parameters: "Release")]
public void UnmanagedStructsAreConsideredBlittableFromDifferentAssembly
(BuildArgs buildArgs, bool libraryHasAttribute, bool appHasAttribute, bool expectSuccess, RunHost host, string id)
=> SeparateAssembliesForDisableRuntimeMarshallingTest(
libraryHasAttribute: libraryHasAttribute,
appHasAttribute: appHasAttribute,
expectSuccess: expectSuccess,
buildArgs,
host,
id
);
private void SeparateAssembliesForDisableRuntimeMarshallingTest
(bool libraryHasAttribute, bool appHasAttribute, bool expectSuccess, BuildArgs buildArgs, RunHost host, string id)
{
string code =
(libraryHasAttribute ? "[assembly: System.Runtime.CompilerServices.DisableRuntimeMarshalling]" : "")
+ "public struct S { public int Value; }";
var libraryBuildArgs = ExpandBuildArgs(
buildArgs with { ProjectName = $"blittable_different_library_{buildArgs.Config}_{id}" },
extraProperties: "<OutputType>Library</OutputType><RuntimeIdentifier />"
);
(string libraryDir, string output) = BuildProject(
libraryBuildArgs,
id: id + "_library",
new BuildProjectOptions(
InitProject: () =>
{
File.WriteAllText(Path.Combine(_projectDir!, "S.cs"), code);
},
Publish: buildArgs.AOT,
DotnetWasmFromRuntimePack: false,
AssertAppBundle: false
)
);
code =
"""
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
"""
+ (appHasAttribute ? "[assembly: DisableRuntimeMarshalling]" : "")
+ """
public class Test
{
public static int Main()
{
var x = new S { Value = 5 };
Console.WriteLine("Main running " + x.Value);
return 42;
}
[UnmanagedCallersOnly]
public static void M(S myStruct) { }
}
""";
buildArgs = ExpandBuildArgs(
buildArgs with { ProjectName = $"blittable_different_app_{buildArgs.Config}_{id}" },
extraItems: $@"<ProjectReference Include='{Path.Combine(libraryDir, libraryBuildArgs.ProjectName + ".csproj")}' />",
extraProperties: buildArgs.AOT
? string.Empty
: "<WasmBuildNative>true</WasmBuildNative>"
);
_projectDir = null;
(_, output) = BuildProject(
buildArgs,
id: id,
new BuildProjectOptions(
InitProject: () =>
{
File.WriteAllText(Path.Combine(_projectDir!, "Program.cs"), code);
},
Publish: buildArgs.AOT,
DotnetWasmFromRuntimePack: false,
ExpectSuccess: expectSuccess
)
);
if (expectSuccess)
{
output = RunAndTestWasmApp(buildArgs, buildDir: _projectDir, expectedExitCode: 42, host: host, id: id);
Assert.Contains("Main running 5", output);
}
else
{
Assert.Matches("error.*Parameter.*types.*pinvoke.*.*blittable", output);
}
}
[Theory]
[BuildAndRun(host: RunHost.Chrome)]
public void DllImportWithFunctionPointers_WarningsAsMessages(BuildArgs buildArgs, RunHost host, string id)
......@@ -154,6 +343,36 @@ public static int Main()
Assert.Contains("Main running", output);
}
[Theory]
[BuildAndRun(host: RunHost.None)]
public void UnmanagedCallback_WithFunctionPointers_CompilesWithWarnings(BuildArgs buildArgs, string id)
{
string code =
"""
using System;
using System.Runtime.InteropServices;
public class Test
{
public static int Main()
{
Console.WriteLine("Main running");
return 42;
}
[UnmanagedCallersOnly]
public unsafe static extern void SomeFunction1(delegate* unmanaged<int> callback);
}
""";
(_, string output) = BuildForVariadicFunctionTests(
code,
buildArgs with { ProjectName = $"cb_fnptr_{buildArgs.Config}" },
id
);
Assert.Matches("warning\\sWASM0001.*Skipping.*Test::SomeFunction1.*because.*function\\spointer", output);
}
[ConditionalTheory(typeof(BuildTestBase), nameof(IsUsingWorkloads))]
[BuildAndRun(host: RunHost.None)]
public void IcallWithOverloadedParametersAndEnum(BuildArgs buildArgs, string id)
......@@ -239,7 +458,7 @@ public static void Main()
,{ "name": "Add(Numbers,Numbers)", "func": "ves_def", "handles": false }
]}
]
""";
projectCode = projectCode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册