#include "./codegen_cuda.h"

#include "megbrain/common.h"
#include "megbrain/jit/ast_c.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/jit/placeholder_opr.h"
#include "megbrain/jit/utils.h"
#include "megbrain/opr/tensor_manip.h"

#include <cinttypes>

#if MGB_JIT && MGB_CUDA

using namespace mgb;
using namespace jit;
using namespace ast_c;

namespace {

using VarNode2AST = ThinHashMap<VarNode*, ASTPtr>;

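//! map a megbrain DType to the CUDA scalar type name used in the generated kernel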
const char* dtype_to_cstr(DType dtype) {
    if (dtype == dtype::Float16())
        return "__half";
    if (dtype == dtype::Float32())
        return "float";
    mgb_throw(GraphError, "unsupported output dtype %s in JIT fusion", dtype.name());
}

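//! generate the kernel snippet that converts the flat global_idx into per-input
//! element offsets (offset_0 .. offset_{nr_inps-1}); each high dimension is
//! divided via Uint32Fastdiv (multiply-high and shift) and the remainder is
//! multiplied by the corresponding stride. The emitted code still contains the
//! {{NDIM}} placeholder, which is resolved later.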
std::string gen_fastdiv_offset(size_t nr_inps) {
    std::string res = "";

    char tmp[100];
#define APPEND(fmt...)                   \
    do {                                 \
        snprintf(tmp, sizeof(tmp), fmt); \
        res += tmp;                      \
    } while (0)

    for (size_t i = 0; i < nr_inps; ++i) {
        APPEND("offset_%zu = 0;\n", i);
    }
    for (size_t i = 0; i < nr_inps; ++i) {
        APPEND("tmp_idx = global_idx;\n");
        APPEND("#pragma unroll\n");
        APPEND("for (int j = {{NDIM}} - 1; j >= 1; --j) {\n");
        APPEND("Uint32Fastdiv& shp = "
               "visitors.m[%zu].m_shape_highdim[j-1];\n",
               i);
        res += R"(
        unsigned int
            ans_for_one = tmp_idx & ~shp.m_divisor_is_not_1,
            dfix = tmp_idx + shp.m_inc_dividend,
            hi32 = __umulhi(dfix, shp.m_mul),
            ans = hi32 >> shp.m_shift,
            idx_div = (ans & shp.m_divisor_is_not_1) | ans_for_one;
        )";
        APPEND("offset_%zu += (tmp_idx - idx_div * shp.m_divisor) * "
               "visitors.m[%zu].m_stride[j];\n",
               i, i);
        APPEND("tmp_idx = idx_div;\n");
        APPEND("}\n");
        APPEND("offset_%zu += tmp_idx * visitors.m[%zu].m_stride[0];\n", i, i);
    }

#undef APPEND
    return res;
}

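//! build an AST expression that loads input \p input_id from data.inputs at its
//! precomputed offset_<input_id>, cast to the input's dtype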
ASTPtr gen_data_ast(size_t input_id, const JITExecutor::Args::Data& n) {
    auto res = ssprintf(
            "(static_cast<%s*>(data.inputs[%zu]))[offset_%zu]",
            dtype_to_cstr(n.layout.dtype), input_id, input_id);
    return ASTPtr::make<VariableAST>(res);
}

//! generate code to access input values in the kernel
void gen_input_code(
        str_util::StrReplaceMap& replace_map, VarNode2AST& var2ast,
        const JITExecutor::Args& args, const PlaceholderArray& placeholders) {
    std::string decl_exps_str, assign_exps_str, decl_fastdiv_offset_str;
    for (size_t i = 0; i < args.inputs.size(); i++) {
        ASTPtr elem_var = ASTPtr::make<VariableAST>("x" + std::to_string(i));
        ASTPtr elem_val = gen_data_ast(i, args.inputs[i]);
        ASTPtr elem_decl =
                ASTPtr::make<DeclFloatAST>(elem_var, CompNode::DeviceType::CUDA);
        ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
        var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var;
        decl_exps_str += elem_decl->code_gen();
        assign_exps_str += elem_assign->code_gen();

        ASTPtr offset_var = ASTPtr::make<VariableAST>("offset_" + std::to_string(i));
        ASTPtr offset_decl = ASTPtr::make<DeclIntAST>(offset_var);
        decl_fastdiv_offset_str += offset_decl->code_gen();
    }
    str_util::append_replace_map(
            replace_map, {{"{{DECL_fastdiv_offset}}", decl_fastdiv_offset_str},
                          {"{{DECL_EXPRS}}", decl_exps_str},
                          {"{{ASSIGN_EXPRS}}", assign_exps_str}});
}

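//! translate a fused operator into an AST expression built from the ASTs of its
//! inputs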
ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) {
    ASTPtrArray cur_inputs;
    for (auto inp_node : opr->input()) {
        cur_inputs.push_back(var2ast.at(inp_node));
    }
    if (opr->same_type<opr::Reduce>() || opr->same_type<opr::GetVarShape>() ||
        opr->same_type<opr::Dimshuffle>()) {
        // Reduce, GetVarShape and Dimshuffle occur in grad graphs and are
        // ignored here: the opr is replaced by its first input
        return {cur_inputs[0]};
    }

    return opr2AST(opr, cur_inputs, CompNode::DeviceType::CUDA).at(0);
}
}  // anonymous namespace

std::pair<std::string, std::string> mgb::jit::codegen_cuda(
        const InternalGraph& internal_graph, const JITExecutor::Args& args,
        bool copy_param_to_dev) {
    std::string cuda_kernel =
            R"(
#include <cuda_fp16.h>

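// plain-data mirror of the host-side Uint32Fastdiv: division by m_divisor is
// performed with a multiply-high and shift (see the __umulhi sequence emitted
// by gen_fastdiv_offset)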
struct Uint32Fastdiv {
    unsigned int m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift;

    static const unsigned int MAX_DIVIDEND = ~0u - 1;
};

template <int ndim>
struct ParamElemVisitor {
    int m_stride[ndim];

    //! m_shape_highdim[i] = original_shape[i + 1]
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
    static const int NDIM = ndim;
};

struct Data {
    void* inputs[{{NR_INPS}}];
    {{OUTPUT_DTYPE}}* output;
};

struct PEVisitors {
    ParamElemVisitor<{{NDIM}}> m[{{NR_INPS}}];
};

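// numerically stable log(exp(x) + exp(y)): add log1p(exp(smaller - larger)) to
// the larger operand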
template<typename T>
static __forceinline__ __device__ T jit_log_sum_exp(T x, T y) {
    T a, b;
    a = x < y ? x : y;
    b = x < y ? y : x;
    return T(b + log1pf(expf(a - b)));
}

)";

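    // when copy_param_to_dev is set, the kernel parameters are staged in device
    // memory and the kernel receives pointers; otherwise Data and PEVisitors are
    // passed by value as kernel arguments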
    cuda_kernel += copy_param_to_dev ? R"(
extern "C" __global__ void {{KERNEL_NAME}} (Data* data_ptr, size_t num_elements, PEVisitors* visitors_ptr) {
    Data data = *data_ptr;
    PEVisitors visitors = *visitors_ptr;
)"
                                     : R"(
extern "C" __global__ void {{KERNEL_NAME}} (Data data, size_t num_elements,
 PEVisitors visitors) { )";

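    // common kernel body: each thread handles up to three elements, with
    // global_idx advancing by the total number of launched threads (delta)
    // between elements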
    cuda_kernel += R"(
    unsigned int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int delta = blockDim.x * gridDim.x;
    unsigned int tmp_idx;

    {{DECL_EXPRS}}
    {{INTERNAL_DECL_EXPRS}}
    {{DECL_fastdiv_offset}}

    if (global_idx < num_elements) {
        {{fastdiv_offset}}
        {{ASSIGN_EXPRS}}
        {{INTERNAL_ASSIGN_EXPRS}}
        data.output[global_idx] = {{EXP}};

        global_idx += delta;
        if (global_idx < num_elements) {
            {{fastdiv_offset}}
            {{ASSIGN_EXPRS}}
            {{INTERNAL_ASSIGN_EXPRS}}
            data.output[global_idx] = {{EXP}};

            global_idx += delta;
            if (global_idx < num_elements) {
                {{fastdiv_offset}}
                {{ASSIGN_EXPRS}}
                {{INTERNAL_ASSIGN_EXPRS}}
                data.output[global_idx] = {{EXP}};
            }
        }
    }
}
)";

    VarNode2AST var2ast;
    str_util::StrReplaceMap source_replace_map;

    // add inputs to the replace map
    gen_input_code(source_replace_map, var2ast, args, internal_graph.placeholders());

    // add other oprs
    std::string internal_decl_exps_str, internal_assign_exps_str;
    size_t cur_opr_cnt = 0;
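    // DepOprIter visits the internal graph in dependency order; every
    // non-placeholder operator gets a fresh variable y<n> declared and assigned
    // from the ASTs of its inputs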
    cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
        ++cur_opr_cnt;
        if (opr->same_type<JITPlaceholder>()) {
            return;
        }
        ASTPtr elem_var = ASTPtr::make<VariableAST>("y" + std::to_string(cur_opr_cnt));
        ASTPtr elem_val = gen_opr_ast(opr, var2ast);
        ASTPtr elem_decl =
                ASTPtr::make<DeclFloatAST>(elem_var, CompNode::DeviceType::CUDA);
        ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
        var2ast[opr->output(0)] = elem_var;
        internal_decl_exps_str += elem_decl->code_gen();
        internal_assign_exps_str += elem_assign->code_gen();
    }}.add(internal_graph.output());

    str_util::append_replace_map(
            source_replace_map,
            {{"{{NR_INPS}}", std::to_string(args.inputs.size())},
             {"{{NDIM}}", std::to_string(args.outputs[0].layout.ndim)},
             {"{{fastdiv_offset}}", gen_fastdiv_offset(args.inputs.size())},
             {"{{INTERNAL_DECL_EXPRS}}", internal_decl_exps_str},
             {"{{INTERNAL_ASSIGN_EXPRS}}", internal_assign_exps_str},
             {"{{EXP}}", var2ast.at(internal_graph.output())->code_gen()},
             {"{{OUTPUT_DTYPE}}", dtype_to_cstr(args.outputs[0].layout.dtype)}});

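    // substitute twice: the first pass expands {{fastdiv_offset}}, whose body
    // itself contains {{NDIM}} placeholders that are resolved by the second pass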
    str_util::replace_all_pairs_inplace(cuda_kernel, source_replace_map);
    str_util::replace_all_pairs_inplace(cuda_kernel, source_replace_map);

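    // derive a unique kernel name from an XXHash of the generated source, then
    // patch it into the kernel code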
    auto kernel_name = ssprintf(
            "jit_nvrtc_%" PRIx64,
            XXHash{}.update(cuda_kernel.data(), cuda_kernel.size()).digest());
    str_util::replace_all_pairs_inplace(
            cuda_kernel, {{"{{KERNEL_NAME}}", kernel_name}});

    if (ExecutableHelper::keep_interm()) {
        ExecutableHelper::get().write_file(
                kernel_name + ".cu",
                "// " + internal_graph.output()->owner_opr()->name() + "\n" +
                        cuda_kernel);
    }

    return {kernel_name, cuda_kernel};
}

#endif  // MGB_JIT && MGB_CUDA

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}