#include "./codegen_cuda.h" #include "megbrain/common.h" #include "megbrain/jit/ast_c.h" #include "megbrain/jit/executor_opr.h" #include "megbrain/jit/placeholder_opr.h" #include "megbrain/jit/utils.h" #include "megbrain/opr/tensor_manip.h" #include #if MGB_JIT && MGB_CUDA using namespace mgb; using namespace jit; using namespace ast_c; namespace { using VarNode2AST = ThinHashMap; const char* dtype_to_cstr(DType dtype) { if (dtype == dtype::Float16()) return "__half"; if (dtype == dtype::Float32()) return "float"; mgb_throw(GraphError, "unsupported output dtype %s in JIT fusion", dtype.name()); } std::string gen_fastdiv_offset(size_t nr_inps) { std::string res = ""; char tmp[100]; #define APPEND(fmt...) \ do { \ snprintf(tmp, sizeof(tmp), fmt); \ res += tmp; \ } while (0) for (size_t i = 0; i < nr_inps; ++i) { APPEND("offset_%zu = 0;\n", i); } for (size_t i = 0; i < nr_inps; ++i) { APPEND("tmp_idx = global_idx;\n"); APPEND("#pragma unroll\n"); APPEND("for (int j = {{NDIM}} - 1; j >= 1; --j) {\n"); APPEND("Uint32Fastdiv& shp = " "visitors.m[%zu].m_shape_highdim[j-1];\n", i); res += R"( unsigned int ans_for_one = tmp_idx & ~shp.m_divisor_is_not_1, dfix = tmp_idx + shp.m_inc_dividend, hi32 = __umulhi(dfix, shp.m_mul), ans = hi32 >> shp.m_shift, idx_div = (ans & shp.m_divisor_is_not_1) | ans_for_one; )"; APPEND("offset_%zu += (tmp_idx - idx_div * shp.m_divisor) * " "visitors.m[%zu].m_stride[j];\n", i, i); APPEND("tmp_idx = idx_div;\n"); APPEND("}\n"); APPEND("offset_%zu += tmp_idx * visitors.m[%zu].m_stride[0];\n", i, i); } #undef APPEND return res; } ASTPtr gen_data_ast(size_t input_id, const JITExecutor::Args::Data& n) { auto res = ssprintf( "(static_cast<%s*>(data.inputs[%zu]))[offset_%zu]", dtype_to_cstr(n.layout.dtype), input_id, input_id); return ASTPtr::make(res); } //! 
//! generate code to access input values in the kernel
void gen_input_code(
        str_util::StrReplaceMap& replace_map, VarNode2AST& var2ast,
        const JITExecutor::Args& args, const PlaceholderArray& placeholders) {
    std::string decl_exps_str, assign_exps_str, decl_fastdiv_offset_str;
    for (size_t i = 0; i < args.inputs.size(); i++) {
        ASTPtr elem_var = ASTPtr::make("x" + std::to_string(i));
        ASTPtr elem_val = gen_data_ast(i, args.inputs[i]);
        ASTPtr elem_decl = ASTPtr::make(elem_var, CompNode::DeviceType::CUDA);
        ASTPtr elem_assign = ASTPtr::make(elem_var, elem_val);
        var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var;
        decl_exps_str += elem_decl->code_gen();
        assign_exps_str += elem_assign->code_gen();

        ASTPtr offset_var = ASTPtr::make("offset_" + std::to_string(i));
        ASTPtr offset_decl = ASTPtr::make(offset_var);
        decl_fastdiv_offset_str += offset_decl->code_gen();
    }
    str_util::append_replace_map(
            replace_map,
            {{"{{DECL_fastdiv_offset}}", decl_fastdiv_offset_str},
             {"{{DECL_EXPRS}}", decl_exps_str},
             {"{{ASSIGN_EXPRS}}", assign_exps_str}});
}

//! translate a fused operator into an AST expression over its already-visited inputs
ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) {
    ASTPtrArray cur_inputs;
    for (auto inp_node : opr->input()) {
        cur_inputs.push_back(var2ast.at(inp_node));
    }

    if (opr->same_type<opr::Broadcast>() || opr->same_type<opr::Reduce>() ||
        opr->same_type<opr::GetVarShape>()) {
        // Reduce and GetVarShape occur in grad and would be ignored
        return {cur_inputs[0]};
    }

    return opr2AST(opr, cur_inputs, CompNode::DeviceType::CUDA).at(0);
}

}  // anonymous namespace

std::pair<std::string, std::string> mgb::jit::codegen_cuda(
        const InternalGraph& internal_graph, const JITExecutor::Args& args,
        bool copy_param_to_dev) {
    // prelude of the generated source: helper structs mirrored on the host side
    // and a numerically stable log-sum-exp helper
    std::string cuda_kernel = R"(
#include <cuda_fp16.h>

struct Uint32Fastdiv {
    unsigned int m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift;
    static const unsigned int MAX_DIVIDEND = ~0u - 1;
};

template <int ndim>
struct ParamElemVisitor {
    int m_stride[ndim];

    //! m_shape_highdim[i] = original_shape[i + 1]
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
    static const int NDIM = ndim;
};

struct Data {
    void* inputs[{{NR_INPS}}];
    {{OUTPUT_DTYPE}}* output;
};

struct PEVisitors {
    ParamElemVisitor<{{NDIM}}> m[{{NR_INPS}}];
};

template <typename T>
static __forceinline__ __device__ T jit_log_sum_exp(T x, T y) {
    T a, b;
    a = x < y ? x : y;
    b = x < y ? y : x;
    return T(b + log1pf(expf(a - b)));
}
)";

    // kernel signature: take Data/PEVisitors by value, or via device pointers
    // when the params have been copied to device memory
    cuda_kernel += copy_param_to_dev ?
R"( extern "C" __global__ void {{KERNEL_NAME}} (Data* data_ptr, size_t num_elements, PEVisitors* visitors_ptr) { Data data = *data_ptr; PEVisitors visitors = *visitors_ptr; )" : R"( extern "C" __global__ void {{KERNEL_NAME}} (Data data, size_t num_elements, PEVisitors visitors) { )"; cuda_kernel += R"( unsigned int global_idx = blockIdx.x * blockDim.x + threadIdx.x; unsigned int delta = blockDim.x * gridDim.x; unsigned int tmp_idx; {{DECL_EXPRS}} {{INTERNAL_DECL_EXPRS}} {{DECL_fastdiv_offset}} if (global_idx < num_elements) { {{fastdiv_offset}} {{ASSIGN_EXPRS}} {{INTERNAL_ASSIGN_EXPRS}} data.output[global_idx] = {{EXP}}; global_idx += delta; if (global_idx < num_elements) { {{fastdiv_offset}} {{ASSIGN_EXPRS}} {{INTERNAL_ASSIGN_EXPRS}} data.output[global_idx] = {{EXP}}; global_idx += delta; if (global_idx < num_elements) { {{fastdiv_offset}} {{ASSIGN_EXPRS}} {{INTERNAL_ASSIGN_EXPRS}} data.output[global_idx] = {{EXP}}; } } } } )"; VarNode2AST var2ast; str_util::StrReplaceMap source_replace_map; // add inputs to the replace map gen_input_code(source_replace_map, var2ast, args, internal_graph.placeholders()); // add other oprs std::string internal_decl_exps_str, internal_assign_exps_str; size_t cur_opr_cnt = 0; cg::DepOprIter{[&](cg::OperatorNodeBase* opr) { ++cur_opr_cnt; if (opr->same_type()) { return; } ASTPtr elem_var = ASTPtr::make("y" + std::to_string(cur_opr_cnt)); ASTPtr elem_val = gen_opr_ast(opr, var2ast); ASTPtr elem_decl = ASTPtr::make(elem_var, CompNode::DeviceType::CUDA); ASTPtr elem_assign = ASTPtr::make(elem_var, elem_val); var2ast[opr->output(0)] = elem_var; internal_decl_exps_str += elem_decl->code_gen(); internal_assign_exps_str += elem_assign->code_gen(); }}.add(internal_graph.output()); str_util::append_replace_map( source_replace_map, {{"{{NR_INPS}}", std::to_string(args.inputs.size())}, {"{{NDIM}}", std::to_string(args.outputs[0].layout.ndim)}, {"{{fastdiv_offset}}", gen_fastdiv_offset(args.inputs.size())}, {"{{INTERNAL_DECL_EXPRS}}", internal_decl_exps_str}, {"{{INTERNAL_ASSIGN_EXPRS}}", internal_assign_exps_str}, {"{{EXP}}", var2ast.at(internal_graph.output())->code_gen()}, {"{{OUTPUT_DTYPE}}", dtype_to_cstr(args.outputs[0].layout.dtype)}}); str_util::replace_all_pairs_inplace(cuda_kernel, source_replace_map); str_util::replace_all_pairs_inplace(cuda_kernel, source_replace_map); auto kernel_name = ssprintf( "jit_nvrtc_%" PRIx64, XXHash{}.update(cuda_kernel.data(), cuda_kernel.size()).digest()); str_util::replace_all_pairs_inplace( cuda_kernel, {{"{{KERNEL_NAME}}", kernel_name}}); if (ExecutableHelper::keep_interm()) { ExecutableHelper::get().write_file( kernel_name + ".cu", "// " + internal_graph.output()->owner_opr()->name() + "\n" + cuda_kernel); } return {kernel_name, cuda_kernel}; } #endif // MGB_JIT && MGB_CUDA // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}