/**
 * \file src/jit/impl/mlir/executable_cpu.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./executable_cpu.h"
#include "./ir/types.h"

#include "megbrain/jit/mlir/ir/utils.h"

#include <mlir/ExecutionEngine/CRunnerUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h>

using namespace mgb;
using namespace jit;

namespace {

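// Build a heap-allocated MLIR StridedMemRefType<T, N> descriptor that
// aliases the tensor's buffer and mirrors its shape and strides. The caller
// owns the returned descriptor and must release it with free().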
template <typename T, int N>
StridedMemRefType<T, N>* get_strided_memref_type(
        const megdnn::TensorND& tensor) {
    using DescType = StridedMemRefType<T, N>;
    DescType* desc = static_cast<DescType*>(malloc(sizeof(DescType)));
    desc->basePtr = tensor.ptr<T>();
    desc->data = tensor.ptr<T>();
    desc->offset = 0;
    for (size_t i = 0; i < tensor.layout.ndim; i++) {
        desc->sizes[i] = tensor.layout.shape[i];
        desc->strides[i] = tensor.layout.stride[i];
    }
    return desc;
}

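// Dispatch on the tensor's dtype to instantiate a memref descriptor with the
// matching element type, for a fixed rank N.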
template <int N>
void* tensor2memref_dim(const megdnn::TensorND& tensor) {
    switch (tensor.layout.dtype.enumv()) {
#define cb(_dtype, _type)           \
    case megdnn::DTypeEnum::_dtype: \
        return get_strided_memref_type<_type, N>(tensor);
        FOR_EACH_DNN_DTYPE(cb)
#undef cb
        default:
            mgb_throw(InternalError, "Unsupported dtype: %s",
                      tensor.layout.dtype.name());
    }
    return nullptr;
}

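// Dispatch on the tensor's rank; only ndim in [1, 5] is supported.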
void* tensor2memref(const megdnn::TensorND& tensor) {
    switch (tensor.layout.ndim) {
#define cb(i) \
    case i:   \
        return tensor2memref_dim<i>(tensor)

        cb(1);
        cb(2);
        cb(3);
        cb(4);
        cb(5);
        default:
            mgb_throw(InternalError, "Unsupported ndim, got %zu",
                      tensor.layout.ndim);
#undef cb
    }
}

}  // namespace
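
// JIT-compile the MLIR module with the MLIR ExecutionEngine and keep the
// resulting engine alive for this executable's lifetime.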
MLIRCPUExecutable::MLIRCPUExecutable(mlir::OwningModuleRef& module,
                                     const std::string& kernel_name)
        : m_kernel_name{kernel_name} {
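    // LLVM optimization pipeline: opt level 3, size level 3, no explicit
    // target machine.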
    auto opt_pipeline = mlir::makeOptimizingTransformer(3, 3, nullptr);
    std::vector<std::string> libs;
    auto&& engine = mlir::ExecutionEngine::create(
            *module, opt_pipeline, llvm::None,
            std::vector<llvm::StringRef>(libs.begin(), libs.end()), true,
            false);
    mgb_assert(engine);
    m_engine = std::move(*engine);
}

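// Marshal the fusion arguments into memref descriptors and invoke the
// JIT-compiled kernel through its C-interface wrapper.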
void MLIRCPUExecutable::execute(JITExecutor* fusion_opr) {
    auto&& args = fusion_opr->args();
    std::vector<void*> args_array(args.inputs.size() + args.outputs.size());
    std::vector<void*> args_array_pointer(args.inputs.size() +
                                          args.outputs.size());
    size_t idx = 0;
    for (size_t i = 0; i < args.inputs.size(); i++) {
        args_array[idx] =
                tensor2memref({args.inputs[i].from->dev_tensor().raw_ptr(),
                               args.inputs[i].layout});
        args_array_pointer[idx] = &args_array[idx];
        idx++;
    }
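    // Marshal the outputs, checking that every output has the same number of
    // elements.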
    int64_t nr_elements = 0;
    for (size_t i = 0; i < args.outputs.size(); i++) {
        if (nr_elements == 0) {
            nr_elements = args.outputs[i].layout.total_nr_elems();
        } else {
            mgb_assert(static_cast<size_t>(nr_elements) ==
                               args.outputs[i].layout.total_nr_elems(),
                       "The number of elements of outputs mismatch, expected: "
                       "%zu got: %zu(%s)",
                       static_cast<size_t>(nr_elements),
                       args.outputs[i].layout.total_nr_elems(),
                       args.outputs[i].layout.to_string().c_str());
        }
        args_array[idx] =
                tensor2memref({args.outputs[i].from->dev_tensor().raw_ptr(),
                               args.outputs[i].layout});
        args_array_pointer[idx] = &args_array[idx];
        idx++;
    }

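    // The generated kernel expects the total element count as an extra
    // trailing argument.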
    args_array_pointer.push_back(&nr_elements);
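
    // Invoke the kernel through the C-interface wrapper MLIR emits (named
    // "_mlir_ciface_<kernel>"), which takes pointers to the memref
    // descriptors.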
    std::string adapter_name = std::string("_mlir_ciface_") + m_kernel_name;
    auto err = m_engine->invoke(
            adapter_name, llvm::MutableArrayRef<void*>(args_array_pointer));
    if (err) {
        mgb_throw(InternalError, "failed to run MLIR kernel %s",
                  m_kernel_name.c_str());
    }

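    // Release the heap-allocated memref descriptors.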
    for (size_t i = 0; i < args_array.size(); i++) {
        free(args_array[i]);
    }
}

MLIRCPUExecutable::~MLIRCPUExecutable() {}

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}