PInvokeTableGenerator.cs 18.3 KB
Newer Older
Z
Zoltan Varga 已提交
1 2 3 4 5
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
6
using System.Diagnostics.CodeAnalysis;
Z
Zoltan Varga 已提交
7 8 9 10 11 12 13
using System.IO;
using System.Linq;
using System.Text;
using System.Reflection;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;

14
internal sealed class PInvokeTableGenerator
Z
Zoltan Varga 已提交
15
{
16
    private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
17

18
    private TaskLoggingHelper Log { get; set; }
Z
Zoltan Varga 已提交
19

20
    /// <summary>
    /// Creates a generator that reports its messages and warnings through <paramref name="log"/>.
    /// </summary>
    /// <param name="log">MSBuild logging helper stored in <see cref="Log"/>.</param>
    public PInvokeTableGenerator(TaskLoggingHelper log)
    {
        Log = log;
    }
21

22
    /// <summary>
    /// Inspects every type in <paramref name="assemblies"/>, collects the pinvokes and native
    /// callbacks they declare, and writes the generated table source to <paramref name="outputPath"/>.
    /// </summary>
    /// <param name="pinvokeModules">Names of the native modules for which tables are emitted.</param>
    /// <param name="assemblies">Paths of the assemblies to inspect; also used to resolve references.</param>
    /// <param name="outputPath">Destination of the generated file.</param>
    /// <returns>The signature strings collected for the pinvoke methods that were found.</returns>
    public IEnumerable<string> Generate(string[] pinvokeModules, string[] assemblies, string outputPath)
    {
        // Keyed by module name; assigning through the indexer tolerates repeated names.
        var moduleLookup = new Dictionary<string, string>();
        foreach (string moduleName in pinvokeModules)
            moduleLookup[moduleName] = moduleName;

        var collectedSignatures = new List<string>();
        var collectedPInvokes = new List<PInvoke>();
        var collectedCallbacks = new List<PInvokeCallback>();

        // The assemblies are loaded for inspection only, through a MetadataLoadContext.
        using var loadContext = new MetadataLoadContext(new PathAssemblyResolver(assemblies), "System.Private.CoreLib");
        foreach (string assemblyPath in assemblies)
        {
            Assembly loaded = loadContext.LoadFromAssemblyPath(assemblyPath);
            foreach (Type candidateType in loaded.GetTypes())
            {
                CollectPInvokes(collectedPInvokes, collectedCallbacks, collectedSignatures, candidateType);
            }
        }

        // Generate into a temporary file first, then hand it to Utils.CopyIfDifferent, which
        // reports whether the destination was updated.
        string scratchPath = Path.GetTempFileName();
        try
        {
            using (StreamWriter writer = File.CreateText(scratchPath))
            {
                EmitPInvokeTable(writer, moduleLookup, collectedPInvokes);
                EmitNativeToInterp(writer, collectedCallbacks);
            }

            bool outputChanged = Utils.CopyIfDifferent(scratchPath, outputPath, useHash: false);
            string status = outputChanged
                ? $"Generating pinvoke table to '{outputPath}'."
                : $"PInvoke table in {outputPath} is unchanged.";
            Log.LogMessage(MessageImportance.Low, status);
        }
        finally
        {
            File.Delete(scratchPath);
        }

        return collectedSignatures;
    }

64
    /// <summary>
    /// Examines every method declared directly on <paramref name="type"/> and records the pinvoke
    /// imports and native callbacks it finds.
    /// </summary>
    /// <param name="pinvokes">Receives one entry per method flagged <c>PinvokeImpl</c>.</param>
    /// <param name="callbacks">Receives one entry per method carrying a callback attribute.</param>
    /// <param name="signatures">Receives the signature string of each recorded pinvoke.</param>
    /// <param name="type">The type whose methods are examined.</param>
    /// <remarks>
    /// A failure while processing a single method is reported as warning WASM0001 and the scan
    /// continues; only <c>LogAsErrorException</c> is allowed to propagate.
    /// </remarks>
    private void CollectPInvokes(List<PInvoke> pinvokes, List<PInvokeCallback> callbacks, List<string> signatures, Type type)
    {
        foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
        {
            try
            {
                CollectPInvokesForMethod(method);
            }
            catch (Exception ex) when (ex is not LogAsErrorException)
            {
                Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
                        $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'");
            }
        }

        void CollectPInvokesForMethod(MethodInfo method)
        {
            if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
            {
                var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
                var module = (string)dllimport.ConstructorArguments[0].Value!;

                // EntryPoint is an optional named argument of DllImportAttribute. When it is not
                // present (or carries no string value) the import binds to the method's own name,
                // so use that instead of failing the whole method.
                string entrypoint = method.Name;
                foreach (CustomAttributeNamedArgument namedArg in dllimport.NamedArguments)
                {
                    if (namedArg.MemberName == "EntryPoint" && namedArg.TypedValue.Value is string explicitEntryPoint)
                    {
                        entrypoint = explicitEntryPoint;
                        break;
                    }
                }

                pinvokes.Add(new PInvoke(entrypoint, module, method));

                string? signature = SignatureMapper.MethodToSignature(method);
                if (signature == null)
                {
                    throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
                }

                Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
                signatures.Add(signature);
            }

            foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
            {
                try
                {
                    if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
                        cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")

                        callbacks.Add(new PInvokeCallback(method));
                }
                catch
                {
                    // Assembly not found, ignore
                }
            }
        }
    }

115 116 117 118 119
    /// <summary>
    /// Writes the C source of the pinvoke lookup tables: a declaration for each imported native
    /// function, one <c>PinvokeImport</c> array per module, and the <c>pinvoke_tables</c> and
    /// <c>pinvoke_names</c> arrays that list those modules.
    /// </summary>
    /// <param name="w">Writer that receives the generated code.</param>
    /// <param name="modules">Modules to emit tables for; only the keys are read.</param>
    /// <param name="pinvokes">
    /// All collected pinvokes. <c>Skip</c> is set on every entry that must be left out of the
    /// tables (variadic-looking imports and imports whose declaration cannot be generated).
    /// </param>
    private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules, List<PInvoke> pinvokes)
    {
        w.WriteLine("// GENERATED FILE, DO NOT MODIFY");
        w.WriteLine();

        var pinvokesGroupedByEntryPoint = pinvokes
                                            .Where(l => modules.ContainsKey(l.Module))
                                            .OrderBy(l => l.EntryPoint)
                                            .GroupBy(l => l.EntryPoint);

        var comparer = new PInvokeComparer();
        foreach (IGrouping<string, PInvoke> group in pinvokesGroupedByEntryPoint)
        {
            // Equal imports (same entry point, module and method text) are declared only once.
            var candidates = group.Distinct(comparer).ToArray();
            PInvoke first = candidates[0];
            if (ShouldTreatAsVariadic(candidates))
            {
                string imports = string.Join(Environment.NewLine,
                                            candidates.Select(
                                                p => $"    {p.Method} (in [{p.Method.DeclaringType?.Assembly.GetName().Name}] {p.Method.DeclaringType})"));
                Log.LogWarning($"Found a native function ({first.EntryPoint}) with varargs in {first.Module}." +
                                 " Calling such functions is not supported, and will fail at runtime." +
                                $" Managed DllImports: {Environment.NewLine}{imports}");

                // Mark the whole group, not just the distinct representatives: the duplicates
                // removed by Distinct() are separate PInvoke objects, and any of them left
                // un-skipped would put this entry point into a table below without a declaration.
                foreach (var c in group)
                    c.Skip = true;

                continue;
            }

            var decls = new HashSet<string>();
            foreach (var candidate in candidates)
            {
                var decl = GenPInvokeDecl(candidate);
                if (decl == null)
                {
                    // GenPInvokeDecl marked only this candidate as skipped; extend that to the
                    // equal imports that Distinct() folded into it.
                    foreach (var duplicate in group)
                    {
                        if (comparer.Equals(duplicate, candidate))
                            duplicate.Skip = true;
                    }

                    continue;
                }

                // Identical declarations can come from different managed signatures.
                if (!decls.Add(decl))
                    continue;

                w.WriteLine(decl);
            }
        }

        foreach (var module in modules.Keys)
        {
            string symbol = ModuleNameToId(module) + "_imports";
            w.WriteLine("static PinvokeImport " + symbol + " [] = {");

            // One row per entry point, followed by a comment naming the assemblies that import it.
            var assemblies_pinvokes = pinvokes.
                Where(l => l.Module == module && !l.Skip).
                OrderBy(l => l.EntryPoint).
                GroupBy(d => d.EntryPoint).
                Select(l => "{\"" + FixupSymbolName(l.Key) + "\", " + FixupSymbolName(l.Key) + "}, " +
                                "// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n)));

            foreach (var pinvoke in assemblies_pinvokes)
            {
                w.WriteLine(pinvoke);
            }

            // Terminating row of the per-module array.
            w.WriteLine("{NULL, NULL}");
            w.WriteLine("};");
        }

        w.Write("static void *pinvoke_tables[] = { ");
        foreach (var module in modules.Keys)
        {
            string symbol = ModuleNameToId(module) + "_imports";
            w.Write(symbol + ",");
        }
        w.WriteLine("};");

        w.Write("static char *pinvoke_names[] = { ");
        foreach (var module in modules.Keys)
        {
            w.Write("\"" + module + "\"" + ",");
        }
        w.WriteLine("};");

        // Turns a module name into the prefix of its table symbol by replacing the
        // characters listed in s_charsToReplace with underscores.
        static string ModuleNameToId(string name)
        {
            if (name.IndexOfAny(s_charsToReplace) < 0)
                return name;

            string fixedName = name;
            foreach (char c in s_charsToReplace)
                fixedName = fixedName.Replace(c, '_');

            return fixedName;
        }

        // Several distinct imports of one entry point that disagree on parameter count are
        // treated as a variadic native function. Methods whose parameters cannot be read
        // are left out of the comparison.
        static bool ShouldTreatAsVariadic(PInvoke[] candidates)
        {
            if (candidates.Length < 2)
                return false;

            PInvoke first = candidates[0];
            if (TryIsMethodGetParametersUnsupported(first.Method, out _))
                return false;

            int firstNumArgs = first.Method.GetParameters().Length;
            return candidates
                        .Skip(1)
                        .Any(c => !TryIsMethodGetParametersUnsupported(c.Method, out _) &&
                                    c.Method.GetParameters().Length != firstNumArgs);
        }
    }

220 221 222 223 224 225 226 227 228 229 230 231 232
    /// <summary>
    /// Rewrites <paramref name="name"/> so that it consists only of ASCII letters, digits and
    /// underscores.
    /// </summary>
    /// <remarks>
    /// The name is processed as UTF-8 bytes. Letters, digits and '_' are kept, the characters in
    /// <c>s_charsToReplace</c> become '_', and any other byte is written as its uppercase
    /// hexadecimal value wrapped in underscores.
    /// </remarks>
    private static string FixupSymbolName(string name)
    {
        var result = new StringBuilder();

        foreach (byte octet in Encoding.UTF8.GetBytes(name))
        {
            char ch = (char)octet;
            bool keepAsIs = ch is (>= '0' and <= '9') or (>= 'a' and <= 'z') or (>= 'A' and <= 'Z') or '_';

            if (keepAsIs)
            {
                result.Append(ch);
                continue;
            }

            if (Array.IndexOf(s_charsToReplace, ch) >= 0)
            {
                result.Append('_');
                continue;
            }

            result.Append('_').Append(octet.ToString("X")).Append('_');
        }

        return result.ToString();
    }

    /// <summary>
    /// Builds a symbol name for <paramref name="method"/> from its assembly name, its declaring
    /// type (full name when the type is nested, simple name otherwise) and the method name,
    /// joined with underscores and passed through <see cref="FixupSymbolName"/>.
    /// </summary>
    private static string SymbolNameForMethod(MethodInfo method)
    {
        Type owner = method.DeclaringType!;
        string? assemblyName = owner.Module.Assembly.GetName().Name;
        string? typeName = owner.IsNested ? owner.FullName : owner.Name;

        return FixupSymbolName($"{assemblyName}_{typeName}_{method.Name}");
    }

259
    /// <summary>
    /// Maps a managed type to the C type name used in generated declarations, based on the
    /// type's simple name. Anything not listed explicitly is emitted as <c>int</c>.
    /// </summary>
    private static string MapType(Type t)
    {
        switch (t.Name)
        {
            case "Void":
                return "void";
            case nameof(Double):
                return "double";
            case nameof(Single):
                return "float";
            case nameof(Int64):
                return "int64_t";
            case nameof(UInt64):
                return "uint64_t";
            default:
                return "int";
        }
    }
Z
Zoltan Varga 已提交
268

269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
    // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
    // https://github.com/dotnet/runtime/issues/43791
    /// <summary>
    /// Probes whether <c>GetParameters()</c> on <paramref name="method"/> fails with
    /// <see cref="NotSupportedException"/>.
    /// </summary>
    /// <param name="method">The method to probe.</param>
    /// <param name="reason">The exception message when the result is true; otherwise null.</param>
    /// <returns>True only when the probe threw <see cref="NotSupportedException"/>.</returns>
    private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotNullWhen(true)] out string? reason)
    {
        string? failureMessage = null;
        bool unsupported = false;

        try
        {
            _ = method.GetParameters();
        }
        catch (NotSupportedException nse)
        {
            failureMessage = nse.Message;
            unsupported = true;
        }
        catch
        {
            // not concerned with other exceptions
        }

        reason = failureMessage;
        return unsupported;
    }

    /// <summary>
    /// Produces the C declaration for the native function imported by <paramref name="pinvoke"/>.
    /// </summary>
    /// <returns>
    /// The declaration text, or null when the method's parameters cannot be read; in that case a
    /// WASM0001 warning is logged and <c>pinvoke.Skip</c> is set.
    /// </returns>
    private string? GenPInvokeDecl(PInvoke pinvoke)
    {
        MethodInfo target = pinvoke.Method;

        if (target.Name == "EnumCalendarInfo")
        {
            // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
            // https://github.com/dotnet/runtime/issues/43791
            return $"int {FixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);";
        }

        if (TryIsMethodGetParametersUnsupported(target, out string? reason))
        {
            Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
                    $"Skipping pinvoke '{target.DeclaringType!.FullName}::{target}' because '{reason}'.");

            pinvoke.Skip = true;
            return null;
        }

        // Evaluated in the order the pieces appear in the declaration.
        string returnType = MapType(target.ReturnType);
        string symbol = FixupSymbolName(pinvoke.EntryPoint);
        string parameterList = string.Join(",", target.GetParameters().Select(prm => MapType(prm.ParameterType)));

        return returnType + " " + symbol + " (" + parameterList + ");";
    }
326

327
    /// <summary>
    /// Writes the C entry functions through which native code calls the managed callbacks,
    /// followed by the arrays that list those entry functions and their lookup keys.
    /// </summary>
    /// <param name="w">Writer that receives the generated code.</param>
    /// <param name="callbacks">Callbacks to emit; <c>EntryName</c> is assigned on each one.</param>
    /// <remarks>
    /// Fails through <see cref="Error"/> when a callback has a return or parameter type that
    /// <see cref="IsBlittable"/> rejects, or when two callbacks produce the same entry name.
    /// </remarks>
    private static void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> callbacks)
    {
        // Generate native->interp entry functions
        // These are called by native code, so they need to obtain
        // the interp entry function/arg from a global array
        // They also need to have a signature matching what the
        // native code expects, which is the native signature
        // of the delegate invoke in the [MonoPInvokeCallback]
        // attribute.
        // Only blittable parameter/return types are supported.
        int cb_index = 0;

        // Arguments to interp entry functions in the runtime
        w.WriteLine("InterpFtnDesc wasm_native_to_interp_ftndescs[" + callbacks.Count + "];");

        // Validate every callback before anything is emitted for any of them.
        foreach (var cb in callbacks)
        {
            MethodInfo method = cb.Method;
            bool isVoid = method.ReturnType.FullName == "System.Void";

            if (!isVoid && !IsBlittable(method.ReturnType))
                Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
            foreach (var p in method.GetParameters())
            {
                if (!IsBlittable(p.ParameterType))
                    Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
            }
        }

        var callbackNames = new HashSet<string>();

        foreach (var cb in callbacks)
        {
            var sb = new StringBuilder();
            var method = cb.Method;

            // Queried once and reused by the three passes below.
            ParameterInfo[] parameters = method.GetParameters();
            bool is_void = method.ReturnType.Name == "Void";

            // The signature of the interp entry function
            // This is a gsharedvt_in signature
            sb.Append("typedef void ");
            sb.Append($" (*WasmInterpEntrySig_{cb_index}) (");
            int pindex = 0;
            if (!is_void)
            {
                // Slot for the return value
                sb.Append("int*");
                pindex++;
            }
            for (int i = 0; i < parameters.Length; i++)
            {
                if (pindex > 0)
                    sb.Append(',');
                sb.Append("int*");
                pindex++;
            }
            if (pindex > 0)
                sb.Append(',');
            // Extra arg
            sb.Append("int*");
            sb.Append(");\n");

            string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
            string class_name = method.DeclaringType.Name;
            string method_name = method.Name;
            string entry_name = $"wasm_native_to_interp_{module_symbol}_{class_name}_{method_name}";
            if (!callbackNames.Add(entry_name))
            {
                Error($"Two callbacks with the same name '{method_name}' are not supported.");
            }
            cb.EntryName = entry_name;

            // Header of the entry function, using the mapped native types
            sb.Append(MapType(method.ReturnType));
            sb.Append($" {entry_name} (");
            for (int i = 0; i < parameters.Length; i++)
            {
                if (i > 0)
                    sb.Append(',');
                sb.Append(MapType(parameters[i].ParameterType));
                sb.Append($" arg{i}");
            }
            sb.Append(") { \n");
            if (!is_void)
                sb.Append(MapType(method.ReturnType) + " res;\n");

            // Body: call through the descriptor, passing the address of the result and of each argument
            sb.Append($"((WasmInterpEntrySig_{cb_index})wasm_native_to_interp_ftndescs [{cb_index}].func) (");
            pindex = 0;
            if (!is_void)
            {
                sb.Append("&res");
                pindex++;
            }
            for (int i = 0; i < parameters.Length; i++)
            {
                if (pindex > 0)
                    sb.Append(", ");
                sb.Append($"&arg{i}");
                pindex++;
            }
            if (pindex > 0)
                sb.Append(", ");
            sb.Append($"wasm_native_to_interp_ftndescs [{cb_index}].arg");
            sb.Append(");\n");
            if (!is_void)
                sb.Append("return res;\n");
            sb.Append('}');
            w.WriteLine(sb);
            cb_index++;
        }

        // Array of function pointers
        w.Write("static void *wasm_native_to_interp_funcs[] = { ");
        foreach (var cb in callbacks)
        {
            w.Write(cb.EntryName + ",");
        }
        w.WriteLine("};");

        // Lookup table from method->interp entry
        // The key is a string of the form <assembly name>_<class name>_<method name>,
        // with '.' in the assembly name replaced by '_'
        // FIXME: Use a better encoding
        w.Write("static const char *wasm_native_to_interp_map[] = { ");
        foreach (var cb in callbacks)
        {
            var method = cb.Method;
            string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
            string class_name = method.DeclaringType.Name;
            string method_name = method.Name;
            w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\",");
        }
        w.WriteLine("};");
    }
462

463
    /// <summary>
    /// Returns true when <paramref name="type"/> is a primitive, by-ref, pointer or enum type,
    /// which are the only types this generator accepts in callback signatures.
    /// </summary>
    private static bool IsBlittable(Type type)
        => type.IsPrimitive || type.IsByRef || type.IsPointer || type.IsEnum;

471
    /// <summary>
    /// Aborts generation by throwing a <c>LogAsErrorException</c> carrying <paramref name="msg"/>.
    /// </summary>
    private static void Error(string msg)
    {
        throw new LogAsErrorException(msg);
    }
Z
Zoltan Varga 已提交
472 473
}

474
#pragma warning disable CA1067
/// <summary>
/// A native function import: the entry point and module it refers to, plus the managed method
/// that declares it.
/// </summary>
internal sealed class PInvoke : IEquatable<PInvoke>
#pragma warning restore CA1067
{
    // Name of the native function.
    public string EntryPoint;
    // Name of the native module the function is imported from.
    public string Module;
    // Managed method declaring the import.
    public MethodInfo Method;
    // Set when this import must be left out of the generated output.
    public bool Skip;

    public PInvoke(string entryPoint, string module, MethodInfo method)
    {
        (EntryPoint, Module, Method) = (entryPoint, module, method);
    }

    /// <summary>
    /// Two imports are equal when entry point, module and the textual form of the method all
    /// match under ordinal comparison. <see cref="Skip"/> does not take part.
    /// </summary>
    public bool Equals(PInvoke? other)
    {
        if (other is null)
            return false;

        if (!string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal))
            return false;

        if (!string.Equals(Module, other.Module, StringComparison.Ordinal))
            return false;

        return string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal);
    }

    public override string ToString()
        => "{ EntryPoint: " + EntryPoint + ", Module: " + Module + ", Method: " + Method + ", Skip: " + Skip + " }";
}

499
/// <summary>
/// Equality comparer for <see cref="PInvoke"/> that delegates to <c>PInvoke.Equals</c> and
/// hashes the concatenation of entry point, module and method text.
/// </summary>
internal sealed class PInvokeComparer : IEqualityComparer<PInvoke>
{
    public bool Equals(PInvoke? x, PInvoke? y)
    {
        // Two nulls are equal; a null never equals a non-null instance.
        if (x is null)
            return y is null;

        return y is not null && x.Equals(y);
    }

    public int GetHashCode(PInvoke pinvoke)
    {
        string key = pinvoke.EntryPoint + pinvoke.Module + pinvoke.Method;
        return key.GetHashCode();
    }
}
514

515
/// <summary>
/// A managed method that native code calls back into.
/// </summary>
internal sealed class PInvokeCallback
{
    // The managed callback method.
    public MethodInfo Method;
    // Name of the generated native entry function; null until one has been assigned.
    public string? EntryName;

    public PInvokeCallback(MethodInfo method) => Method = method;
}