// PInvokeTableGenerator.cs (last committed by Zoltan Varga)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;

/// <summary>
/// MSBuild task that scans managed assemblies for DllImport methods and for native-callable
/// callbacks ([UnmanagedCallersOnly] / [MonoPInvokeCallback]) and writes a C source file
/// containing the pinvoke lookup tables and the native-to-interpreter entry functions.
/// </summary>
public class PInvokeTableGenerator : Task
{
    /// <summary>Native module names whose pinvokes are emitted into the tables.</summary>
    [Required, NotNull]
    public string[]? Modules { get; set; }

    /// <summary>Paths of the assemblies to scan; also the set used to resolve their references.</summary>
    [Required, NotNull]
    public string[]? Assemblies { get; set; }

    /// <summary>Path of the generated C file.</summary>
    [Required, NotNull]
    public string? OutputPath { get; set; }

    /// <summary>Set to <see cref="OutputPath"/> once generation completes.</summary>
    [Output]
    public string FileWrites { get; private set; } = string.Empty;

    // Characters that cannot appear in C identifiers and are mapped to '_'.
    private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };

    public override bool Execute()
    {
        if (Assemblies.Length == 0)
        {
            Log.LogError("No assemblies given to scan for pinvokes");
            return false;
        }

        if (Modules.Length == 0)
        {
            Log.LogError($"{nameof(PInvokeTableGenerator)}.{nameof(Modules)} cannot be empty");
            return false;
        }

        try
        {
            GenPInvokeTable(Modules, Assemblies);
            return !Log.HasLoggedErrors;
        }
        catch (LogAsErrorException laee)
        {
            // Raised by Error(): a user-facing problem reported as a build error.
            Log.LogError(laee.Message);
            return false;
        }
    }

    /// <summary>
    /// Scans <paramref name="assemblies"/> and writes the pinvoke tables plus the
    /// native-to-interp callback entry functions to <see cref="OutputPath"/>.
    /// The output file is only rewritten when its content changed.
    /// </summary>
    /// <param name="pinvokeModules">Native module names to include in the tables.</param>
    /// <param name="assemblies">Assembly paths to load (via MetadataLoadContext) and scan.</param>
    public void GenPInvokeTable(string[] pinvokeModules, string[] assemblies)
    {
        var modules = new Dictionary<string, string>();
        foreach (var module in pinvokeModules)
            modules[module] = module;

        var pinvokes = new List<PInvoke>();
        var callbacks = new List<PInvokeCallback>();

        var resolver = new PathAssemblyResolver(assemblies);
        using var mlc = new MetadataLoadContext(resolver, "System.Private.CoreLib");
        foreach (var aname in assemblies)
        {
            var a = mlc.LoadFromAssemblyPath(aname);
            foreach (var type in a.GetTypes())
                CollectPInvokes(pinvokes, callbacks, type);
        }

        string tmpFileName = Path.GetTempFileName();
        try
        {
            using (var w = File.CreateText(tmpFileName))
            {
                EmitPInvokeTable(w, modules, pinvokes);
                EmitNativeToInterp(w, callbacks);
            }

            if (Utils.CopyIfDifferent(tmpFileName, OutputPath, useHash: false))
                Log.LogMessage(MessageImportance.Low, $"Generating pinvoke table to '{OutputPath}'.");
            else
                Log.LogMessage(MessageImportance.Low, $"PInvoke table in {OutputPath} is unchanged.");
            FileWrites = OutputPath;
        }
        finally
        {
            // Remove the scratch file even when Error() aborts generation midway.
            File.Delete(tmpFileName);
        }
    }

    // Collects the pinvokes and native-callable callbacks declared directly on `type`.
    // A failure for one method is logged at low importance and only that method is skipped.
    private void CollectPInvokes(List<PInvoke> pinvokes, List<PInvokeCallback> callbacks, Type type)
    {
        foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
        {
            try
            {
                CollectPInvokesForMethod(method);
            }
            catch (Exception ex)
            {
                Log.LogMessage(MessageImportance.Low, $"Could not get pinvoke, or callbacks for method {method.Name}: {ex}");
            }
        }

        void CollectPInvokesForMethod(MethodInfo method)
        {
            if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
            {
                var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
                var module = (string)dllimport.ConstructorArguments[0].Value!;
                // EntryPoint is optional on DllImport; when absent the native entry point is the method name.
                string? entrypoint = dllimport.NamedArguments
                                        .Where(arg => arg.MemberName == "EntryPoint")
                                        .Select(arg => arg.TypedValue.Value as string)
                                        .FirstOrDefault();
                pinvokes.Add(new PInvoke(entrypoint ?? method.Name, module, method));
            }

            foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
            {
                try
                {
                    if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
                        cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
                    {
                        // Register each method once, even if it carries both attributes;
                        // a duplicate would later be reported as a callback name clash.
                        callbacks.Add(new PInvokeCallback(method));
                        break;
                    }
                }
                catch
                {
                    // The attribute type's assembly could not be resolved; ignore this attribute.
                }
            }
        }
    }

    // Writes the C declarations for every non-skipped pinvoke, one `<module>_imports` table
    // per module, and the `pinvoke_tables` / `pinvoke_names` arrays indexed by module.
    private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules, List<PInvoke> pinvokes)
    {
        w.WriteLine("// GENERATED FILE, DO NOT MODIFY");
        w.WriteLine();

        var pinvokesGroupedByEntryPoint = pinvokes
                                            .Where(l => modules.ContainsKey(l.Module))
                                            .OrderBy(l => l.EntryPoint)
                                            .GroupBy(l => l.EntryPoint);

        var comparer = new PInvokeComparer();
        foreach (IGrouping<string, PInvoke> group in pinvokesGroupedByEntryPoint)
        {
            var candidates = group.Distinct(comparer).ToArray();
            PInvoke first = candidates[0];
            if (ShouldTreatAsVariadic(candidates))
            {
                string imports = string.Join(Environment.NewLine,
                                            candidates.Select(
                                                p => $"    {p.Method} (in [{p.Method.DeclaringType?.Assembly.GetName().Name}] {p.Method.DeclaringType})"));
                Log.LogWarning($"Found a native function ({first.EntryPoint}) with varargs in {first.Module}." +
                                 " Calling such functions is not supported, and will fail at runtime." +
                                $" Managed DllImports: {Environment.NewLine}{imports}");

                // Skip the whole group, including duplicates Distinct() folded away, so the
                // module tables below never reference a symbol that was not declared.
                foreach (PInvoke member in group)
                    member.Skip = true;

                continue;
            }

            var decls = new HashSet<string>();
            foreach (var candidate in candidates)
            {
                var decl = GenPInvokeDecl(candidate);
                if (decl == null || !decls.Add(decl))
                    continue;

                w.WriteLine(decl);
            }

            // GenPInvokeDecl may have marked a candidate as skipped; propagate that to the
            // duplicates Distinct() folded into it, which would otherwise still reach the tables.
            foreach (PInvoke member in group)
            {
                if (!member.Skip && candidates.Any(c => c.Skip && comparer.Equals(c, member)))
                    member.Skip = true;
            }
        }

        foreach (var module in modules.Keys)
        {
            string symbol = ModuleNameToId(module) + "_imports";
            w.WriteLine("static PinvokeImport " + symbol + " [] = {");

            var assemblies_pinvokes = pinvokes
                .Where(l => l.Module == module && !l.Skip)
                .OrderBy(l => l.EntryPoint)
                .GroupBy(d => d.EntryPoint)
                .Select(l => "{\"" + FixupSymbolName(l.Key) + "\", " + FixupSymbolName(l.Key) + "}, " +
                             "// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n)));

            foreach (var pinvoke in assemblies_pinvokes)
                w.WriteLine(pinvoke);

            w.WriteLine("{NULL, NULL}");
            w.WriteLine("};");
        }

        w.Write("static void *pinvoke_tables[] = { ");
        foreach (var module in modules.Keys)
        {
            string symbol = ModuleNameToId(module) + "_imports";
            w.Write(symbol + ",");
        }
        w.WriteLine("};");

        w.Write("static char *pinvoke_names[] = { ");
        foreach (var module in modules.Keys)
        {
            w.Write("\"" + module + "\"" + ",");
        }
        w.WriteLine("};");

        // Maps a module name to a C identifier by replacing s_charsToReplace with '_'.
        static string ModuleNameToId(string name)
        {
            if (name.IndexOfAny(s_charsToReplace) < 0)
                return name;

            string fixedName = name;
            foreach (char c in s_charsToReplace)
                fixedName = fixedName.Replace(c, '_');

            return fixedName;
        }

        // Two or more distinct imports of one entry point with differing parameter counts
        // are treated as a varargs native function. Candidates whose parameters cannot be
        // decoded are ignored for this check.
        static bool ShouldTreatAsVariadic(PInvoke[] candidates)
        {
            if (candidates.Length < 2)
                return false;

            PInvoke first = candidates[0];
            if (TryIsMethodGetParametersUnsupported(first.Method, out _))
                return false;

            int firstNumArgs = first.Method.GetParameters().Length;
            return candidates
                        .Skip(1)
                        .Any(c => !TryIsMethodGetParametersUnsupported(c.Method, out _) &&
                                    c.Method.GetParameters().Length != firstNumArgs);
        }
    }

    // Turns `name` into a C identifier: ASCII letters, digits and '_' are kept,
    // s_charsToReplace become '_', and every other UTF-8 byte becomes `_<HEX>_`.
    private static string FixupSymbolName(string name)
    {
        byte[] bytes = Encoding.UTF8.GetBytes(name);
        StringBuilder sb = new();

        foreach (byte b in bytes)
        {
            if ((b >= (byte)'0' && b <= (byte)'9') ||
                (b >= (byte)'a' && b <= (byte)'z') ||
                (b >= (byte)'A' && b <= (byte)'Z') ||
                (b == (byte)'_'))
            {
                sb.Append((char)b);
            }
            else if (s_charsToReplace.Contains((char)b))
            {
                sb.Append('_');
            }
            else
            {
                sb.Append($"_{b:X}_");
            }
        }

        return sb.ToString();
    }

    // Builds `<assembly>_<type>_<method>` (full name for nested types) as a C identifier.
    // NOTE(review): not referenced anywhere else in this file.
    private static string SymbolNameForMethod(MethodInfo method)
    {
        StringBuilder sb = new();
        Type type = method.DeclaringType!;
        sb.Append($"{type.Module!.Assembly!.GetName()!.Name!}_");
        sb.Append($"{(type.IsNested ? type.FullName : type.Name)}_");
        sb.Append(method.Name);

        return FixupSymbolName(sb.ToString());
    }

    // Maps a managed type to the C type used in generated signatures. Anything not listed
    // (including pointers, byrefs and enums) is emitted as `int`.
    // NOTE(review): presumably relies on a 32-bit (wasm32) target; confirm.
    private static string MapType(Type t) => t.Name switch
    {
        "Void" => "void",
        "Double" => "double",
        "Single" => "float",
        "Int64" => "int64_t",
        "UInt64" => "uint64_t",
        _ => "int",
    };

    // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
    // https://github.com/dotnet/runtime/issues/43791
    private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotNullWhen(true)] out string? reason)
    {
        try
        {
            method.GetParameters();
        }
        catch (NotSupportedException nse)
        {
            reason = nse.Message;
            return true;
        }
        catch
        {
            // not concerned with other exceptions
        }

        reason = null;
        return false;
    }

    // Returns the C prototype for `pinvoke`, or null (and marks it Skip) when its
    // parameters cannot be decoded.
    private string? GenPInvokeDecl(PInvoke pinvoke)
    {
        var method = pinvoke.Method;
        string symbol = FixupSymbolName(pinvoke.EntryPoint);

        if (method.Name == "EnumCalendarInfo")
        {
            // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
            // https://github.com/dotnet/runtime/issues/43791
            return $"int {symbol} (int, int, int, int, int);";
        }

        if (TryIsMethodGetParametersUnsupported(method, out string? reason))
        {
            Log.LogWarning($"Skipping the following DllImport because '{reason}'. {Environment.NewLine}  {pinvoke.Method}");
            pinvoke.Skip = true;
            return null;
        }

        string parameters = string.Join(",", method.GetParameters().Select(p => MapType(p.ParameterType)));
        return $"{MapType(method.ReturnType)} {symbol} ({parameters});";
    }

    private static void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> callbacks)
    {
        // Generate native->interp entry functions.
        // These are called by native code, so they need to obtain
        // the interp entry function/arg from a global array.
        // They also need to have a signature matching what the
        // native code expects, which is the native signature
        // of the delegate invoke in the [MonoPInvokeCallback]
        // attribute.
        // Only blittable parameter/return types are supported.
        int cb_index = 0;

        // Arguments to interp entry functions in the runtime
        w.WriteLine("InterpFtnDesc wasm_native_to_interp_ftndescs[" + callbacks.Count + "];");

        foreach (var cb in callbacks)
        {
            MethodInfo method = cb.Method;
            bool isVoid = method.ReturnType.FullName == "System.Void";

            if (!isVoid && !IsBlittable(method.ReturnType))
                Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
            foreach (var p in method.GetParameters())
            {
                if (!IsBlittable(p.ParameterType))
                    Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
            }
        }

        var callbackNames = new HashSet<string>();

        foreach (var cb in callbacks)
        {
            var sb = new StringBuilder();
            var method = cb.Method;
            ParameterInfo[] parameters = method.GetParameters();
            bool is_void = method.ReturnType.Name == "Void";

            // The signature of the interp entry function
            // This is a gsharedvt_in signature
            sb.Append("typedef void ");
            sb.Append($" (*WasmInterpEntrySig_{cb_index}) (");
            int pindex = 0;
            if (!is_void)
            {
                sb.Append("int*");
                pindex++;
            }
            foreach (var p in parameters)
            {
                if (pindex > 0)
                    sb.Append(',');
                sb.Append("int*");
                pindex++;
            }
            if (pindex > 0)
                sb.Append(',');
            // Extra arg
            sb.Append("int*");
            sb.Append(");\n");

            string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
            string class_name = method.DeclaringType.Name;
            string method_name = method.Name;
            string entry_name = $"wasm_native_to_interp_{module_symbol}_{class_name}_{method_name}";
            if (!callbackNames.Add(entry_name))
                Error($"Two callbacks with the same name '{method_name}' are not supported.");
            cb.EntryName = entry_name;

            // The native-facing entry function: native signature, forwards to the interp entry.
            sb.Append(MapType(method.ReturnType));
            sb.Append($" {entry_name} (");
            for (int i = 0; i < parameters.Length; i++)
            {
                if (i > 0)
                    sb.Append(',');
                sb.Append(MapType(parameters[i].ParameterType));
                sb.Append($" arg{i}");
            }
            sb.Append(") { \n");
            if (!is_void)
                sb.Append(MapType(method.ReturnType) + " res;\n");
            sb.Append($"((WasmInterpEntrySig_{cb_index})wasm_native_to_interp_ftndescs [{cb_index}].func) (");
            pindex = 0;
            if (!is_void)
            {
                sb.Append("&res");
                pindex++;
            }
            for (int i = 0; i < parameters.Length; i++)
            {
                if (pindex > 0)
                    sb.Append(", ");
                sb.Append($"&arg{i}");
                pindex++;
            }
            if (pindex > 0)
                sb.Append(", ");
            sb.Append($"wasm_native_to_interp_ftndescs [{cb_index}].arg");
            sb.Append(");\n");
            if (!is_void)
                sb.Append("return res;\n");
            sb.Append('}');
            w.WriteLine(sb);
            cb_index++;
        }

        // Array of function pointers
        w.Write("static void *wasm_native_to_interp_funcs[] = { ");
        foreach (var cb in callbacks)
            w.Write(cb.EntryName + ",");
        w.WriteLine("};");

        // Lookup table from method->interp entry, in the same order as wasm_native_to_interp_funcs.
        // The key is a string of the form <assembly name>_<class name>_<method name>,
        // with '.' in the assembly name replaced by '_'.
        // FIXME: Use a better encoding
        w.Write("static const char *wasm_native_to_interp_map[] = { ");
        foreach (var cb in callbacks)
        {
            var method = cb.Method;
            string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
            string class_name = method.DeclaringType.Name;
            string method_name = method.Name;
            w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\",");
        }
        w.WriteLine("};");
    }

    // Only primitives, byrefs, pointers and enums may cross the native callback boundary.
    private static bool IsBlittable(Type type)
        => type.IsPrimitive || type.IsByRef || type.IsPointer || type.IsEnum;

    // Aborts generation; Execute() turns this into a logged build error.
    private static void Error(string msg) => throw new LogAsErrorException(msg);
}

#pragma warning disable CA1067
/// <summary>
/// A DllImport found while scanning: the native entry point, the module it is imported
/// from, and the managed method that declares it.
/// </summary>
internal sealed class PInvoke : IEquatable<PInvoke>
#pragma warning restore CA1067
{
    public PInvoke(string entryPoint, string module, MethodInfo method)
    {
        EntryPoint = entryPoint;
        Module = module;
        Method = method;
    }

    // Native symbol name to bind to.
    public string EntryPoint;
    // Native module name given to the DllImport.
    public string Module;
    // Managed method carrying the DllImport.
    public MethodInfo Method;
    // Set when this pinvoke must be left out of the generated tables.
    public bool Skip;

    /// <summary>
    /// Ordinal value equality on <see cref="EntryPoint"/>, <see cref="Module"/> and the
    /// string form of <see cref="Method"/>. <see cref="Skip"/> is not compared.
    /// NOTE(review): MethodInfo.ToString() appears to omit the declaring type, so identical
    /// signatures declared on different types compare equal.
    /// </summary>
    public bool Equals(PInvoke? other)
        => other is not null &&
            string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) &&
            string.Equals(Module, other.Module, StringComparison.Ordinal) &&
            string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal);

    public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}";
}

/// <summary>
/// Equality comparer for <see cref="PInvoke"/> that defers to PInvoke.Equals and hashes
/// the same fields it compares.
/// </summary>
internal sealed class PInvokeComparer : IEqualityComparer<PInvoke>
{
    public bool Equals(PInvoke? x, PInvoke? y)
    {
        if (ReferenceEquals(x, y))
            return true;
        if (x is null || y is null)
            return false;

        return x.Equals(y);
    }

    // Hashes exactly the fields PInvoke.Equals compares. The '\0' separators keep field
    // boundaries distinct, e.g. ("ab", "c") vs ("a", "bc").
    public int GetHashCode(PInvoke pinvoke)
        => $"{pinvoke.EntryPoint}\0{pinvoke.Module}\0{pinvoke.Method}".GetHashCode();
}

/// <summary>
/// A managed method that native code calls back into, found via its
/// [UnmanagedCallersOnly] or [MonoPInvokeCallback] attribute.
/// </summary>
internal sealed class PInvokeCallback
{
    public PInvokeCallback(MethodInfo method)
    {
        Method = method;
    }

    // The managed callback method.
    public MethodInfo Method;
    // Name of the generated C entry function; null until the emitter assigns it.
    public string? EntryName;
}