Commit fb13a068 authored by Megvii Engine Team

feat(jit/opencl): add OpenCL tiny compiler

also some misc fixes:
* fix jit backends env config logic
* fix OpenCL device property detection for image support
* disable OpenCL jit on devices that do not support image
* do not fuse oprs that are not in CD4 format with OpenCL jit
* fix jit test build with clang and without rtti

GitOrigin-RevId: 9311b270d10b13bd8d0ed7831780ae76cac00af6
Parent ed64f0f6
...@@ -47,10 +47,10 @@ LITE_API inline LiteAlgoSelectStrategy operator|( ...@@ -47,10 +47,10 @@ LITE_API inline LiteAlgoSelectStrategy operator|(
* @param no_profiling_on_shape_change do not re-profile to select best implement * @param no_profiling_on_shape_change do not re-profile to select best implement
* algo when input shape changes (use previous algo) * algo when input shape changes (use previous algo)
* *
* @param jit_level Execute supported operators with JIT (support MLIR, * @param jit_level Execute supported operators with JIT, please check with
* NVRTC). Can only be used on Nvidia GPUs and X86 CPU, this value indicates JIT level: * MGB_JIT_BACKEND for more details, this value indicates JIT level.
* level 1: for JIT execute with basic elemwise operator * 1: for JIT execute with basic elemwise operator
* level 2: for JIT execute elemwise and reduce operators * 2: for JIT execute elemwise and reduce operators
* *
* @param record_level flags to optimize the inference performance with record the * @param record_level flags to optimize the inference performance with record the
* kernel tasks in first run, hereafter the inference all need is to execute the * kernel tasks in first run, hereafter the inference all need is to execute the
......
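A minimal sketch of setting this option from the Lite C++ API, assuming the `lite::Config`/`lite::Options` types with the `jit_level` field documented above (header path and model file name are illustrative):

```cpp
#include <memory>

#include "lite/network.h"  // assumed Lite C++ header exposing Config/Options/Network

int main() {
    lite::Config config;
    // jit_level = 1: JIT fuses basic elemwise operators only;
    // jit_level = 2: JIT also fuses elemwise and reduce operators.
    config.options.jit_level = 1;

    auto network = std::make_shared<lite::Network>(config);
    network->load_model("model.mge");  // hypothetical model path
    network->forward();
    network->wait();
    return 0;
}
```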
...@@ -36,10 +36,10 @@ extern "C" { ...@@ -36,10 +36,10 @@ extern "C" {
* \param no_profiling_on_shape_change do not re-profile to select best impl * \param no_profiling_on_shape_change do not re-profile to select best impl
* algo when input shape changes (use previous algo) * algo when input shape changes (use previous algo)
* *
* \param jit_level Execute supported operators with JIT (support MLIR, * \param jit_level Execute supported operators with JIT, please check with
* NVRTC). Can only be used on Nvidia GPUs, this value indicates JIT level: * MGB_JIT_BACKEND for more details, this value indicates JIT level.
* 1 for basic elemwise opr; * 1: for basic elemwise opr
* 2 for including reduce operator * 2: for including reduce operator
* *
 * \param record_level flag to optimize the inference performance with record the * \param record_level flag to optimize the inference performance with record the
* kernel tasks in first run, hereafter the inference all need to execute the * kernel tasks in first run, hereafter the inference all need to execute the
......
...@@ -744,8 +744,8 @@ DEFINE_uint64(workspace_limit, SIZE_MAX, "set workspace upbound limit"); ...@@ -744,8 +744,8 @@ DEFINE_uint64(workspace_limit, SIZE_MAX, "set workspace upbound limit");
///////////////////////// other options for optimization ///////////////// ///////////////////////// other options for optimization /////////////////
DEFINE_bool( DEFINE_bool(
enable_jit, false, enable_jit, false,
" Execute supported operators with JIT(now only support NVRTC). " " Execute supported operators with JIT, please check with MGB_JIT_BACKEND for "
"Can only be used on Nvidia GPUs"); "more details");
#if MGB_ENABLE_TENSOR_RT #if MGB_ENABLE_TENSOR_RT
DEFINE_bool( DEFINE_bool(
tensorrt, false, tensorrt, false,
......
...@@ -41,8 +41,8 @@ class LiteOptions(Structure): ...@@ -41,8 +41,8 @@ class LiteOptions(Structure):
no_profiling_on_shape_change: do not re-profile to select best implement no_profiling_on_shape_change: do not re-profile to select best implement
algo when input shape changes (use previous algo) algo when input shape changes (use previous algo)
jit_level: Execute supported operators with JIT (support MLIR, jit_level: Execute supported operators with JIT, please check with MGB_JIT_BACKEND
NVRTC). Can only be used on Nvidia GPUs and X86 CPU, this value indicates JIT level: for more details, this value indicates JIT level:
level 1: for JIT execute with basic elemwise operator level 1: for JIT execute with basic elemwise operator
......
...@@ -97,7 +97,7 @@ void DeviceMemoryAllocator::alloc_dynamic( ...@@ -97,7 +97,7 @@ void DeviceMemoryAllocator::alloc_dynamic(
} }
void DeviceMemoryAllocator::defrag_prealloc_contig( void DeviceMemoryAllocator::defrag_prealloc_contig(
ComputingGraph* graph, CompNode comp_node, ComputingGraph* /*graph*/, CompNode comp_node,
size_t size){MGB_TRY{comp_node.free_device(comp_node.alloc_device(size)); size_t size){MGB_TRY{comp_node.free_device(comp_node.alloc_device(size));
} }
MGB_CATCH(MemAllocError&, {}) MGB_CATCH(MemAllocError&, {})
...@@ -574,10 +574,13 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( ...@@ -574,10 +574,13 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
(options().graph_opt.jit || options().graph_opt.jit_config.enabled())) { (options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
// Deprecated usage added previously. It allows NVRTC JIT optimization // Deprecated usage added previously. It allows NVRTC JIT optimization
// when graph_opt_level is 0. This usage is not recommended any more. // when graph_opt_level is 0. This usage is not recommended any more.
unsigned int max_warm = 9;
do {
mgb_log_warn( mgb_log_warn(
"It is not recommanded to enable JIT optimization when " "It is not recommanded to enable JIT optimization when "
"graph_opt_level is 0."); "graph_opt_level is 0, try config graph_opt_level more than 0");
setenv("MGB_JIT_BACKEND", "NVRTC", 1); } while (max_warm-- > 0);
gopt::GraphOptimizer optimizer; gopt::GraphOptimizer optimizer;
optimizer.add_pass<gopt::JITFusionPass>( optimizer.add_pass<gopt::JITFusionPass>(
sopr_stat.has_virtual_grad, options().graph_opt.jit, sopr_stat.has_virtual_grad, options().graph_opt.jit,
......
...@@ -859,11 +859,12 @@ const SeqModifierForSublinearMemory::SeqModifyAction& SeqModifierForSublinearMem ...@@ -859,11 +859,12 @@ const SeqModifierForSublinearMemory::SeqModifyAction& SeqModifierForSublinearMem
msg.push_back('\n'); msg.push_back('\n');
msg.append(ssprintf("m_min_bottleneck: %-10.2f\n", m_min_bottleneck * SIZE2MB)); msg.append(ssprintf("m_min_bottleneck: %-10.2f\n", m_min_bottleneck * SIZE2MB));
if (!m_par_modifier->m_config->genetic_nr_iter) { if (!m_par_modifier->m_config->genetic_nr_iter) {
msg.append( msg.append(ssprintf(
ssprintf("\nGenetic algorithm is currently DISABLED, " "\nGenetic algorithm is currently DISABLED, "
"set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]" "set %c%cB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
" to a positive integer to set the number of iterations" " to a positive integer to set the number of iterations"
" in genetic algorithm.\n")); " in genetic algorithm.\n",
'M', 'G'));
} }
mgb_log_debug("%s", msg.c_str()); mgb_log_debug("%s", msg.c_str());
#else #else
...@@ -934,10 +935,11 @@ SeqModifierForSublinearMemory::SeqModifyAction SeqModifierForSublinearMemory:: ...@@ -934,10 +935,11 @@ SeqModifierForSublinearMemory::SeqModifyAction SeqModifierForSublinearMemory::
planner_concur = m_config->num_worker; planner_concur = m_config->num_worker;
} }
mgb_log_debug( std::string msg = ssprintf(
"use %zu threads to search for sublinear memory plan; " "use %zu threads to search for sublinear memory plan; this can be changed "
"this can be changed via MGB_SUBLINEAR_MEMORY_WORKERS env var", "via %c%cB_SUBLINEAR_MEMORY_WORKERS env var",
planner_concur); planner_concur, 'M', 'G');
mgb_log_debug("%s", msg.c_str());
for (auto&& i : m_planner_thread_pool.start(planner_concur)) for (auto&& i : m_planner_thread_pool.start(planner_concur))
m_thread2planner[i].reset(new ModifyActionPlanner{this}); m_thread2planner[i].reset(new ModifyActionPlanner{this});
......
...@@ -41,8 +41,8 @@ The detection is implemented in [impl/fusion_pass.cpp](impl/fusion_pass.cpp), ...@@ -41,8 +41,8 @@ The detection is implemented in [impl/fusion_pass.cpp](impl/fusion_pass.cpp),
the main detection logic is in function *Fusion::Impl::on_opr*. Compared to nnvm the main detection logic is in function *Fusion::Impl::on_opr*. Compared to nnvm
fusion, our fusion logic can fuse more operators into one fusion kernel. fusion, our fusion logic can fuse more operators into one fusion kernel.
For now , JIT just support CUDA, but it has reserved interface to extend other For now, JIT supports CUDA via HALIDE or NVRTC, CPU via MLIR, and OpenCL via TINYOPENCL;
platforms. it also reserves an interface for extending to more platforms.
## How to enable JIT ## How to enable JIT
You can set `graph_opt_level` to 3 to enable JIT. You can set `graph_opt_level` to 3 to enable JIT.
...@@ -58,9 +58,11 @@ cg.set_option('graph_opt_level', 3) ...@@ -58,9 +58,11 @@ cg.set_option('graph_opt_level', 3)
You can set environment variable `MGB_JIT_BACKEND` to select the JIT backend. You can set environment variable `MGB_JIT_BACKEND` to select the JIT backend.
| Backend | Platforms | Reduction support | Kernel Binary Cache | Kernel Reuse | Noncontig Input | | Backend | Platforms | Reduction support | Kernel Binary Cache | Kernel Reuse | Noncontig Input |
|---------|-----------|-------------------|---------------------|--------------|-----------------| |------------|-----------|-------------------|---------------------|--------------|-----------------|
| HALIDE | CUDA | Y | No | Shape | No | | HALIDE | CUDA | Y | No | Shape | No |
| NVRTC | CUDA | N | Via PersistentCache | Bcast type | Monotone | | NVRTC | CUDA | N | Via PersistentCache | Bcast type | Monotone |
| MLIR | CPU | N | No | Kernel hash | Monotone |
| TINYOPENCL | OpenCL | N | Via OpenCL cache | Kernel hash | Monotone |
To enable fusion of Reduce oprs, set `graph_opt.jit = 2` in graph options. To enable fusion of Reduce oprs, set `graph_opt.jit = 2` in graph options.
......
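A sketch of wiring this up from C++ (backend names are taken from the table above; the `ComputingGraph::Options` fields `graph_opt_level` and `graph_opt.jit` are assumed to match the current MegBrain API):

```cpp
#include <cstdlib>

#include "megbrain/graph.h"  // assumed header for mgb::cg::ComputingGraph

// Enable JIT fusion on a graph and pick a backend explicitly.
void enable_jit(mgb::cg::ComputingGraph& graph, const char* backend) {
    // e.g. "NVRTC", "HALIDE", "MLIR" or "TINYOPENCL"; if unset, the fusion
    // pass chooses a default backend for the device.
    setenv("MGB_JIT_BACKEND", backend, 1);

    graph.options().graph_opt_level = 3;  // enable the JIT fusion pass
    graph.options().graph_opt.jit = 2;    // 2: also fuse Reduce oprs
}
```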
...@@ -53,16 +53,22 @@ ASTPtr gen_powc(ASTPtr inp, float exp) { ...@@ -53,16 +53,22 @@ ASTPtr gen_powc(ASTPtr inp, float exp) {
return make_call("powf", {inp, exp}); return make_call("powf", {inp, exp});
} }
} // anonymous namespace } // anonymous namespace
const ElemGeneratorMap& ast_c::elem_opr_generator() { const ElemGeneratorMap& ast_c::elem_opr_generator(CompNode::DeviceType device_type) {
#define ENTRY(_mode, _impl) \ #define ENTRY(_mode, _impl) \
{ \ { \
ElemMode::_mode, { \ ElemMode::_mode, { \
[](const ASTPtrArray& inps) -> ASTPtrArray { return {_impl}; } \ [=](const ASTPtrArray& inps, bool is_half) -> ASTPtrArray { \
MGB_MARK_USED_VAR(is_half); \
return {_impl}; \
} \
} \ } \
} }
static ElemGeneratorMap map = {
//! other backends map
static ElemGeneratorMap other_map = {
// unary // unary
ENTRY(RELU, make_call("fmaxf", {inps[0], 0.f})), ENTRY(RELU, make_call("fmaxf", {inps[0], 0.f})),
ENTRY(ABS, make_call("fabsf", inps)), ENTRY(ABS, make_call("fabsf", inps)),
...@@ -102,7 +108,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { ...@@ -102,7 +108,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
ENTRY(SWITCH_GT0, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], 0)), ENTRY(SWITCH_GT0, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], 0)),
ENTRY(TANH_GRAD, (1 - inps[0] * inps[0]) * inps[1]), ENTRY(TANH_GRAD, (1 - inps[0] * inps[0]) * inps[1]),
ENTRY(TRUE_DIV, inps[0] / inps[1]), ENTRY(TRUE_DIV, inps[0] / inps[1]),
ENTRY(LOG_SUM_EXP, make_call("mgb_log_sum_exp", {inps[0], inps[1]})), ENTRY(LOG_SUM_EXP, make_call("jit_log_sum_exp", {inps[0], inps[1]})),
ENTRY(LT, ASTPtr::make<BinaryAST>("<", inps[0], inps[1])), ENTRY(LT, ASTPtr::make<BinaryAST>("<", inps[0], inps[1])),
ENTRY(LEQ, ASTPtr::make<BinaryAST>("<=", inps[0], inps[1])), ENTRY(LEQ, ASTPtr::make<BinaryAST>("<=", inps[0], inps[1])),
ENTRY(EQ, ASTPtr::make<BinaryAST>("==", inps[0], inps[1])), ENTRY(EQ, ASTPtr::make<BinaryAST>("==", inps[0], inps[1])),
...@@ -133,22 +139,28 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { ...@@ -133,22 +139,28 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) / 0.f}) /
6.f), 6.f),
}; };
mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER); mgb_assert(other_map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
return map;
return other_map;
#undef ADD_OPR #undef ADD_OPR
} }
ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs) { ASTPtrArray ast_c::opr2AST(
cg::OperatorNodeBase* opr, const ASTPtrArray& inputs,
CompNode::DeviceType device_type) {
using namespace opr; using namespace opr;
if (auto elem = gopt::try_cast_as_op<Elemwise>(opr)) { if (auto elem = gopt::try_cast_as_op<Elemwise>(opr)) {
if (check_elem_mode(elem->param().mode)) { if (check_elem_mode(elem->param().mode, device_type)) {
return elem_opr_generator().find(elem->param().mode)->second(inputs); return elem_opr_generator(device_type)
.find(elem->param().mode)
->second(inputs, false);
} }
} }
if (auto powc = gopt::try_cast_as_op<PowC>(opr)) { if (auto powc = gopt::try_cast_as_op<PowC>(opr)) {
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
return {gen_powc(inputs[0], powc->param().exp)}; return {gen_powc(inputs[0], powc->param().exp)};
} }
...@@ -157,6 +169,7 @@ ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs) ...@@ -157,6 +169,7 @@ ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs)
if (imm.valid()) { if (imm.valid()) {
auto dtype = imm->dtype(); auto dtype = imm->dtype();
if (dtype == dtype::Int32{}) { if (dtype == dtype::Int32{}) {
return {ASTPtr::make<IntAST>(imm->get<int>())}; return {ASTPtr::make<IntAST>(imm->get<int>())};
} }
float scalar_value; float scalar_value;
...@@ -169,10 +182,12 @@ ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs) ...@@ -169,10 +182,12 @@ ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs)
InternalError, "dtype(%s) is not any of [Float16, Float32, Int32]", InternalError, "dtype(%s) is not any of [Float16, Float32, Int32]",
dtype.name()); dtype.name());
} }
return {ASTPtr::make<FloatAST>(scalar_value)};
return {ASTPtr::make<FloatAST>(scalar_value, device_type, false)};
} }
if (opr->same_type<opr::TypeCvt>()) { if (opr->same_type<opr::TypeCvt>()) {
// simply ignore TypeCvt oprs. // simply ignore TypeCvt oprs.
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
return inputs; return inputs;
......
...@@ -67,40 +67,50 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) { ...@@ -67,40 +67,50 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
} }
MGB_LOCK_GUARD(holder->mtx); MGB_LOCK_GUARD(holder->mtx);
auto&& compiler = holder->dev2compiler[comp_node.device_type()]; auto&& compiler = holder->dev2compiler[comp_node.device_type()];
auto backend = MGB_GETENV("MGB_JIT_BACKEND"); auto backend = ::std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
mgb_assert(
backend,
"code issue happened, need call config_jit_backends before get compiler");
//! please keep logic with JITFusionPass::Impl::config_jit_backends
if (!compiler) { if (!compiler) {
switch (comp_node.device_type()) { switch (comp_node.device_type()) {
#if MGB_CUDA #if MGB_CUDA
case CompNode::DeviceType::CUDA: case CompNode::DeviceType::CUDA:
#if MGB_JIT_HALIDE #if MGB_JIT_HALIDE
if (!backend || !strcmp(backend, "HALIDE")) { if (!strcmp(backend, "HALIDE")) {
compiler = std::make_unique<HalideCudaCompiler>(); compiler = std::make_unique<HalideCudaCompiler>();
break; break;
} }
#endif #endif
#if MGB_JIT_MLIR #if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) { if (!strcmp(backend, "MLIR")) {
compiler = compiler =
std::make_unique<MLIRCompiler>(CompNode::DeviceType::CUDA); std::make_unique<MLIRCompiler>(CompNode::DeviceType::CUDA);
break; break;
} }
#endif #endif
if (!backend || !strcmp(backend, "NVRTC")) { if (!strcmp(backend, "NVRTC")) {
compiler = std::make_unique<CudaCompiler>(); compiler = std::make_unique<CudaCompiler>();
break; break;
} }
mgb_throw(InternalError, "No compiler support for cuda"); mgb_throw(
InternalError,
"No compiler support for cuda, may caused by build not enable "
"MLIR/HALIDE module or error config jit backend env");
break; break;
#endif #endif
case CompNode::DeviceType::CPU: case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR #if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) { if (!strcmp(backend, "MLIR")) {
compiler = compiler =
std::make_unique<MLIRCompiler>(CompNode::DeviceType::CPU); std::make_unique<MLIRCompiler>(CompNode::DeviceType::CPU);
break; break;
} }
#endif #endif
mgb_throw(InternalError, "No compiler support for cpu"); mgb_throw(
InternalError,
"No compiler support for cpu, may caused by build not enable "
"MLIR module or error config jit backend env");
break; break;
default: default:
mgb_throw( mgb_throw(
......
#include "megbrain/jit/fusion_pass.h" #include "megbrain/jit/fusion_pass.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/gtrans.h" #include "megbrain/gopt/gtrans.h"
#include "megbrain/jit/ast_c.h" #include "megbrain/jit/ast_c.h"
#include "megbrain/jit/compiler.h" #include "megbrain/jit/compiler.h"
#include "megbrain/jit/internal_graph.h" #include "megbrain/jit/internal_graph.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megdnn/tensor_format.h"
#if MGB_JIT #if MGB_JIT
...@@ -66,6 +68,9 @@ class JITFusionPass::Impl final { ...@@ -66,6 +68,9 @@ class JITFusionPass::Impl final {
return num; return num;
} }
//! config jit backends
void config_jit_backends(CompNode comp_node) const;
public: public:
Impl(bool after_grad, JITFeatureBits feature_bits, OptState& opt_state) Impl(bool after_grad, JITFeatureBits feature_bits, OptState& opt_state)
: m_after_grad{after_grad}, : m_after_grad{after_grad},
...@@ -77,6 +82,57 @@ public: ...@@ -77,6 +82,57 @@ public:
} }
}; };
void JITFusionPass::Impl::config_jit_backends(CompNode comp_node) const {
#define ENV_CB(VALUE) \
if (!backend || !strcmp(backend, VALUE)) { \
if (!backend) { \
mgb_log_debug("config jit default backend to %s", VALUE); \
setenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str(), VALUE, 1); \
} \
break; \
}
auto backend = ::std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
if (backend) {
mgb_log_debug("use user config jit backend with: %s", backend);
}
switch (comp_node.device_type()) {
#if MGB_CUDA
// CUDA jit default property: HALIDE > MLIR > NVRTC
case CompNode::DeviceType::CUDA:
#if MGB_JIT_HALIDE
ENV_CB("HALIDE");
#endif
#if MGB_JIT_MLIR
ENV_CB("MLIR");
#endif
ENV_CB("NVRTC");
mgb_throw(
InternalError,
"No compiler support for cuda, may caused by build not enable "
"MLIR/HALIDE module or error config jit backend env");
break;
#endif
// CPU jit only support MLIR now
case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
ENV_CB("MLIR");
#endif
mgb_throw(
InternalError,
"No compiler support for cpu, may caused by build not enable "
"MLIR module or error config jit backend env");
break;
default:
mgb_throw(
InternalError,
"unsupported JIT config: "
"comp_node=%s backend_setting=%s",
comp_node.to_string().c_str(), backend);
}
#undef ENV_CB
}
void JITFusionPass::Impl::detect_fusion() { void JITFusionPass::Impl::detect_fusion() {
std::vector<OperatorNodeBase*> topo_order; std::vector<OperatorNodeBase*> topo_order;
m_opt_state.graph().iter([this, &topo_order](OperatorNodeBase* opr) { m_opt_state.graph().iter([this, &topo_order](OperatorNodeBase* opr) {
...@@ -86,8 +142,19 @@ void JITFusionPass::Impl::detect_fusion() { ...@@ -86,8 +142,19 @@ void JITFusionPass::Impl::detect_fusion() {
} }
}); });
//! call config_jit_backends as soon as possible
for (auto opr : reverse_adaptor(topo_order)) {
auto&& cn = opr->output(0)->comp_node();
if (cn == CompNode::default_cpu()) {
continue;
}
config_jit_backends(cn);
break;
}
for (auto opr : reverse_adaptor(topo_order)) { for (auto opr : reverse_adaptor(topo_order)) {
if (can_be_fused(opr)) { if (can_be_fused(opr)) {
mgb_log_debug("%s: try process : %s", __FUNCTION__, opr->cname());
process_opr(opr); process_opr(opr);
} }
} }
...@@ -317,11 +384,11 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const { ...@@ -317,11 +384,11 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
return false; return false;
} }
//! As MLIR backend has some contraints auto backend = ::std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
const char* backend = MGB_GETENV("MGB_JIT_BACKEND"); mgb_assert(
if (!backend) { backend,
backend = "DEFAULT"; "code issue happened, need call config_jit_backends before check opr can "
} "be fused");
// float elemwise // float elemwise
if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) { if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
bool ret = true; bool ret = true;
...@@ -361,11 +428,15 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const { ...@@ -361,11 +428,15 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
#undef FOREACH_ELEMWISE_SKIP_MODE #undef FOREACH_ELEMWISE_SKIP_MODE
} }
#endif // MGB_JIT_MLIR #endif // MGB_JIT_MLIR
return ret && ast_c::check_elem_mode(elem->param().mode) &&
return ret &&
ast_c::check_elem_mode(
elem->param().mode, opr->output(0)->comp_node().device_type()) &&
elem->output(0)->dtype().category() == DTypeCategory::FLOAT; elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
} }
if (strcmp(backend, "MLIR")) { //! TINYOPENCL and MLIR only support elemwise now
if (strcmp(backend, "MLIR") && strcmp(backend, "TINYOPENCL")) {
if (opr->same_type<opr::PowC>()) { if (opr->same_type<opr::PowC>()) {
return true; return true;
} }
......
...@@ -82,7 +82,8 @@ void gen_input_code( ...@@ -82,7 +82,8 @@ void gen_input_code(
for (size_t i = 0; i < args.inputs.size(); i++) { for (size_t i = 0; i < args.inputs.size(); i++) {
ASTPtr elem_var = ASTPtr::make<VariableAST>("x" + std::to_string(i)); ASTPtr elem_var = ASTPtr::make<VariableAST>("x" + std::to_string(i));
ASTPtr elem_val = gen_data_ast(i, args.inputs[i]); ASTPtr elem_val = gen_data_ast(i, args.inputs[i]);
ASTPtr elem_decl = ASTPtr::make<DeclFloatAST>(elem_var); ASTPtr elem_decl =
ASTPtr::make<DeclFloatAST>(elem_var, CompNode::DeviceType::CUDA);
ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val); ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var; var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var;
decl_exps_str += elem_decl->code_gen(); decl_exps_str += elem_decl->code_gen();
...@@ -109,7 +110,7 @@ ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) { ...@@ -109,7 +110,7 @@ ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) {
return {cur_inputs[0]}; return {cur_inputs[0]};
} }
return opr2AST(opr, cur_inputs).at(0); return opr2AST(opr, cur_inputs, CompNode::DeviceType::CUDA).at(0);
} }
} // anonymous namespace } // anonymous namespace
...@@ -145,7 +146,7 @@ struct PEVisitors { ...@@ -145,7 +146,7 @@ struct PEVisitors {
}; };
template<typename T> template<typename T>
static __forceinline__ __device__ T mgb_log_sum_exp(T x, T y) { static __forceinline__ __device__ T jit_log_sum_exp(T x, T y) {
T a, b; T a, b;
a = x < y ? x : y; a = x < y ? x : y;
b = x < y ? y : x; b = x < y ? y : x;
...@@ -213,7 +214,8 @@ extern "C" __global__ void {{KERNEL_NAME}} (Data data, size_t num_elements, ...@@ -213,7 +214,8 @@ extern "C" __global__ void {{KERNEL_NAME}} (Data data, size_t num_elements,
} }
ASTPtr elem_var = ASTPtr::make<VariableAST>("y" + std::to_string(cur_opr_cnt)); ASTPtr elem_var = ASTPtr::make<VariableAST>("y" + std::to_string(cur_opr_cnt));
ASTPtr elem_val = gen_opr_ast(opr, var2ast); ASTPtr elem_val = gen_opr_ast(opr, var2ast);
ASTPtr elem_decl = ASTPtr::make<DeclFloatAST>(elem_var); ASTPtr elem_decl =
ASTPtr::make<DeclFloatAST>(elem_var, CompNode::DeviceType::CUDA);
ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val); ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
var2ast[opr->output(0)] = elem_var; var2ast[opr->output(0)] = elem_var;
internal_decl_exps_str += elem_decl->code_gen(); internal_decl_exps_str += elem_decl->code_gen();
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#if MGB_JIT && MGB_CUDA #if MGB_JIT && MGB_CUDA
#include <dlfcn.h>
#include <nvrtc.h> #include <nvrtc.h>
using namespace mgb; using namespace mgb;
......
#include "./codegen_opencl.h"
#include "./utils.h"
#include "megbrain/common.h"
#include "megbrain/jit/ast_c.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/jit/placeholder_opr.h"
#include "megbrain/jit/utils.h"
#include "megbrain/opr/tensor_manip.h"
#include <cinttypes>
#if MGB_JIT && MGB_OPENCL
using namespace mgb;
using namespace jit;
using namespace ast_c;
namespace {
using VarNode2AST = ThinHashMap<VarNode*, ASTPtr>;
//! generate code to access input values in the kernel
void gen_input_code_and_gen_input_data_update(
str_util::StrReplaceMap& replace_map, VarNode2AST& var2ast,
const JITExecutor::Args& args, const PlaceholderArray& placeholders,
bool is_half) {
std::string decl_exps_str, input_data_read_str;
std::string read_image_func = is_half ? "read_imageh" : "read_imagef";
std::string scaler_dec_prefix = is_half ? "__global half* x" : "__global float* x";
auto&& b_info = get_channel_broadcast_info(args);
for (size_t i = 0; i < args.inputs.size(); i++) {
//! gen input args
ASTPtr elem_var_raw =
ASTPtr::make<VariableAST>("x_after_read" + std::to_string(i));
ASTPtr elem_var = ASTPtr::make<VariableAST>(
"__read_only image2d_t x" + std::to_string(i));
ASTPtr elem_var_scalar_offset;
if (LayoutType::SCALAR == b_info[i]) {
elem_var = ASTPtr::make<VariableAST>(scaler_dec_prefix + std::to_string(i));
elem_var_scalar_offset = ASTPtr::make<VariableAST>(
"const uint x_offset" + std::to_string(i));
}
var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var_raw;
decl_exps_str += elem_var->code_gen() + ", ";
if (LayoutType::SCALAR == b_info[i]) {
decl_exps_str += elem_var_scalar_offset->code_gen() + ", ";
}
//! gen input data update
ASTPtr elem_var_raw_input = ASTPtr::make<VariableAST>("x" + std::to_string(i));
elem_var_raw = ASTPtr::make<VariableAST>(
(is_half ? "half4 x_after_read" : "float4 x_after_read") +
std::to_string(i));
std::string coord = "coord";
if (LayoutType::BROADCAST == b_info[i]) {
coord = "coord_b";
}
std::string read_method = read_image_func + "(" +
elem_var_raw_input->code_gen() + ", " + coord + ")";
if (LayoutType::SCALAR == b_info[i]) {
if (is_half) {
read_method = "(half4)(vload_half(x_offset" + std::to_string(i) +
", x" + std::to_string(i) + "))";
} else {
read_method = "(float4)(vload(x_offset" + std::to_string(i) + ", x" +
std::to_string(i) + "))";
}
}
ASTPtr elem_assign = ASTPtr::make<AssignAST>(
elem_var_raw, ASTPtr::make<VariableAST>(read_method));
input_data_read_str += elem_assign->code_gen();
}
str_util::append_replace_map(
replace_map, {
{"{{KERNEL_SRC_ARGS}}", decl_exps_str},
{"{{ASSIGN_EXPRS}}", input_data_read_str},
});
}
ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) {
mgb_assert(
!opr->same_type<opr::Reduce>() && !opr->same_type<opr::GetVarShape>() &&
!opr->same_type<opr::Dimshuffle>() && !opr->same_type<opr::PowC>(),
"OpenCL jit not support Reduce/GetVarShape/Dimshuffle/PowC type now");
ASTPtrArray cur_inputs;
for (auto inp_node : opr->input()) {
cur_inputs.push_back(var2ast.at(inp_node));
}
return opr2AST(opr, cur_inputs, CompNode::DeviceType::OPENCL).at(0);
}
} // anonymous namespace
std::pair<std::string, std::string> mgb::jit::codegen_opencl(
const InternalGraph& internal_graph, const JITExecutor::Args& args) {
std::string opencl_kernel = R"(
__kernel void {{KERNEL_NAME}} (
{{KERNEL_SRC_ARGS}}
__write_only image2d_t dst,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int wc_size,
__private const int hb_size,
__private const uint w_size
) {
#if OPENCL_ENABLE_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP;
int wc = get_global_id(0);
int hb = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
if (wc >= global_size_dim0 || hb >= global_size_dim1)
return;
#endif
for (; hb < hb_size; hb += global_size_dim1) {
for (; wc < wc_size; wc += global_size_dim0) {
int2 coord = (int2)(wc, hb);
int2 coord_b = (int2)(wc / w_size, 0);
{{INTERNAL_DECL_EXPRS}}
{{ASSIGN_EXPRS}}
{{INTERNAL_ASSIGN_EXPRS}}
{{WRITE_IMAGE}}(dst, coord, {{EXP}});
}
wc = get_global_id(0);
}
}
)";
auto input_dtype = args.inputs[0].layout.dtype;
for (size_t i = 0; i < args.inputs.size(); i++) {
mgb_assert(
args.inputs[i].layout.dtype == input_dtype,
"OpenCL jit all oprs should have same dtype");
}
mgb_assert(
args.outputs.size() == 1 && args.outputs[0].layout.dtype == input_dtype,
"output size should be 1 and output dtype should be same with input");
mgb_assert(
dtype::Float16() == input_dtype || dtype::Float32() == input_dtype,
"OpenCL jit dtype only support float32 or float16, %s not support",
input_dtype.name());
auto is_half = dtype::Float16() == input_dtype;
VarNode2AST var2ast;
str_util::StrReplaceMap source_replace_map;
// add inputs to the replace map
gen_input_code_and_gen_input_data_update(
source_replace_map, var2ast, args, internal_graph.placeholders(), is_half);
// add other oprs
std::string internal_decl_exps_str, internal_assign_exps_str;
std::string write_image_func = is_half ? "write_imageh" : "write_imagef";
size_t cur_opr_cnt = 0;
cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
++cur_opr_cnt;
if (opr->same_type<JITPlaceholder>()) {
return;
}
ASTPtr elem_var = ASTPtr::make<VariableAST>("y" + std::to_string(cur_opr_cnt));
ASTPtr elem_val = gen_opr_ast(opr, var2ast);
ASTPtr elem_decl = ASTPtr::make<DeclFloatAST>(
elem_var, CompNode::DeviceType::OPENCL, is_half);
ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
var2ast[opr->output(0)] = elem_var;
internal_decl_exps_str += elem_decl->code_gen();
internal_assign_exps_str += elem_assign->code_gen();
}}.add(internal_graph.output());
str_util::append_replace_map(
source_replace_map,
{{"{{INTERNAL_DECL_EXPRS}}", internal_decl_exps_str},
{"{{INTERNAL_ASSIGN_EXPRS}}", internal_assign_exps_str},
{"{{WRITE_IMAGE}}", write_image_func},
{"{{EXP}}", var2ast.at(internal_graph.output())->code_gen()}});
str_util::replace_all_pairs_inplace(opencl_kernel, source_replace_map);
auto kernel_name = ssprintf(
"jit_opencl_%" PRIx64,
XXHash{}.update(opencl_kernel.data(), opencl_kernel.size()).digest());
str_util::replace_all_pairs_inplace(
opencl_kernel, {{"{{KERNEL_NAME}}", kernel_name}});
return {kernel_name, opencl_kernel};
}
#endif // MGB_JIT && MGB_OPENCL
#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_OPENCL
#include "megbrain/jit/executor_opr.h"
namespace mgb {
namespace jit {
/*!
* \brief generate opencl kernel source code
* \return (kernel name, kernel source)
*/
std::pair<std::string, std::string> codegen_opencl(
const InternalGraph& internal_graph, const JITExecutor::Args& args);
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_OPENCL
#include "megbrain_build_config.h"
#include "megdnn/tensor_format.h"
#if MGB_JIT && MGB_OPENCL
#include "./codegen_opencl.h"
#include "./compiler.h"
#include "./utils.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/rdnn/management.h"
#include "megbrain/utils/timer.h"
using namespace mgb;
using namespace jit;
/* ==================== OpenCLTinyCompiler ===================== */
OpenCLTinyCompiler::OpenCLTinyCompiler(CompNode::DeviceType device_type) {
m_is_debug = ::std::getenv("OPENCL_JIT_DEBUG") ? true : false;
mgb_assert(
CompNode::DeviceType::OPENCL == device_type,
"error init OpenCLTinyCompiler");
}
std::unique_ptr<Executable> OpenCLTinyCompiler::do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) {
std::string source, kernel_name;
std::tie(kernel_name, source) = codegen_opencl(graph, args);
if (m_is_debug) {
mgb_log_debug("kernel: name: %s\n%s", kernel_name.c_str(), source.c_str());
}
auto ret = std::make_unique<OpenCLExecutable>(
std::move(source), std::move(kernel_name), m_is_debug);
return ret;
}
size_t OpenCLTinyCompiler::get_nr_workspace_outputs(JITExecutor* opr) const {
MGB_MARK_USED_VAR(opr);
return 0;
}
void OpenCLTinyCompiler::init_workspace_size_infer(JITExecutor* opr) {
MGB_MARK_USED_VAR(opr);
}
/* =================== OpenCLExecutable ==================== */
OpenCLExecutable::OpenCLExecutable(std::string source, std::string name, bool is_debug)
: m_source{std::move(source)}, m_name{std::move(name)}, m_is_debug{is_debug} {}
void OpenCLExecutable::execute(JITExecutor* fusion_opr) {
auto&& cn = fusion_opr->comp_node();
auto& env = CompNodeEnv::from_comp_node(cn).opencl_env();
auto handle = mgb::opr::intl::get_megdnn_handle(cn);
auto mgr = env.opencl_mgr;
auto&& ctx = mgr->context();
auto& queue = mgr->command_queue();
auto&& kernel = megdnn::opencl::OpenCLKernel(handle);
auto& args = fusion_opr->args();
static auto&& prop = megcore::opencl::OpenCLProp(mgr->device());
bool is_adreno = prop.is_adreno();
bool is_mali = prop.is_mali();
auto max_work_group = static_cast<uint32_t>(prop.max_work_group_size());
mgb_assert(
prop.is_support_image(),
"code issue happened, OpenCL jit only support device with support image");
//! for debug
MGB_MARK_USED_VAR(ctx);
MGB_MARK_USED_VAR(queue);
size_t WGSX = 0;
size_t WGSY = 0;
//! create cl args
for (size_t i = 0; i < args.inputs.size(); i++) {
if (TensorFormat::Type::IMAGE2D_PACK4 == args.inputs[i].layout.format.type()) {
WGSX = std::max(
WGSX,
args.inputs[i]
.layout.format.as_impl<megdnn::Image2DPack4TensorFormat>()
.image_width(args.inputs[i].layout));
WGSY = std::max(
WGSY,
args.inputs[i]
.layout.format.as_impl<megdnn::Image2DPack4TensorFormat>()
.image_height(args.inputs[i].layout));
}
}
mgb_assert(WGSX > 0 && WGSY > 0, "invalid tensor for OpenCL jit");
if (m_is_debug) {
mgb_log_debug(
"OpenCLExecutable init input tensor array with size: %zu, init output "
"tensor array with size: %zu",
args.inputs.size(), args.outputs.size());
for (size_t i = 0; i < args.inputs.size(); i++) {
mgb_log_debug(
"input(%zu) dim: %zu %s", i, args.inputs[i].layout.ndim,
args.inputs[i].layout.to_string().c_str());
}
for (size_t i = 0; i < args.outputs.size(); i++) {
mgb_log_debug(
"output(%zu) dim: %zu %s", i, args.outputs[i].layout.ndim,
args.outputs[i].layout.to_string().c_str());
}
}
mgb_assert(
args.outputs.size() == 1, "OpenCL elemwise jit output size should be one");
//! create kernel
std::string compile_options;
kernel.set_meta_data({compile_options, m_source});
kernel.set_kernel_name(m_name);
kernel.build_kernel();
//! set tensor args
for (size_t i = 0; i < args.inputs.size(); i++) {
if (TensorFormat::Type::IMAGE2D_PACK4 == args.inputs[i].layout.format.type()) {
kernel.add_tensor_image_args(
{{args.inputs[i].from->dev_tensor().raw_ptr(),
args.inputs[i].layout}});
} else {
//! scalar default format case
kernel.add_tensor_arg(
{args.inputs[i].from->dev_tensor().raw_ptr(),
args.inputs[i].layout});
}
}
kernel.add_tensor_image_args(
{{args.outputs[0].from->dev_tensor().raw_ptr(), args.outputs[0].layout}});
uint32_t block_w = 1, block_h = 1, dimx = 1, dimy = 1;
auto config_super_parameter = [&] {
if (is_adreno) {
block_w = 1;
dimx = 64;
dimy = 1;
} else if (is_mali) {
block_w = 1;
dimx = 96;
dimy = 1;
} else {
//! unknown gpu case
block_w = 1;
dimx = 64;
dimy = 1;
}
//! float16 case
if (dtype::Float16() == args.inputs[0].layout.dtype) {
dimx *= 2;
}
//! scaling dimx less than gws0, dimy less than gws1
dimx = std::min(dimx, static_cast<uint32_t>((WGSX + block_w - 1) / block_w));
dimy = std::min(dimy, static_cast<uint32_t>((WGSY + block_h - 1) / block_h));
//! scaling dimx * dimy less than device max_work_group
dimx = std::min(
dimx, std::max(static_cast<uint32_t>(1), max_work_group / dimy));
};
config_super_parameter();
//! set other args and config lws and gws
int wc_size = WGSX;
int hb_size = WGSY;
WGSX = (WGSX + block_w - 1) / block_w;
WGSY = (WGSY + block_h - 1) / block_h;
int i_WGSX = safe_int<size_t>(WGSX);
int i_WGSY = safe_int<size_t>(WGSY);
kernel.add_args(
{{&i_WGSX, sizeof(int)},
{&i_WGSY, sizeof(int)},
{&wc_size, sizeof(int)},
{&hb_size, sizeof(int)}});
//! have broadcasted_channel_like_input case
int may_w_size = args.outputs[0].layout[3];
kernel.add_arg({&may_w_size, sizeof(cl_uint)});
mgb_log_debug(
"config OpenCL jit kernel args: lws: (%d %d), i_WGSX: %d, i_WGSY: %d "
"wc_size: %d, hb_size: %d, w_size: %d",
dimx, dimy, i_WGSX, i_WGSY, wc_size, hb_size, may_w_size);
kernel.set_local_size({dimx, dimy});
kernel.set_global_size_divup_consider_uniform_gws({WGSX, WGSY});
//! enqueue kernel
kernel.run();
}
#endif  // MGB_JIT && MGB_OPENCL
#pragma once
#include "megbrain_build_config.h"
#if MGB_OPENCL
#include "megbrain/jit/compiler.h"
namespace mgb {
namespace jit {
/*!
* \brief Executable class for OPENCL
*/
class OpenCLExecutable final : public Executable {
public:
OpenCLExecutable(std::string source, std::string name, bool is_debug);
~OpenCLExecutable() = default;
/*!
* \brief execute
 * An Executable instance can be executed by one or more fusion_opr
*/
void execute(JITExecutor* fusion_opr) override final;
private:
const std::string m_source;
const std::string m_name;
bool m_is_debug;
};
/*!
 * \brief OpenCL tiny compiler; for now it only handles elemwise oprs and calls the DNN OpenCL runtime
*/
class OpenCLTinyCompiler final : public Compiler {
std::unique_ptr<Executable> do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) override;
bool m_is_debug;
public:
OpenCLTinyCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::OPENCL);
Property property() const override {
using F = Property::Flag;
return Property{F::BIND_NDIM | F::BIND_SHAPE, JITFeatureBits::NONE, 64};
}
size_t get_nr_workspace_outputs(JITExecutor* opr) const override;
void init_workspace_size_infer(JITExecutor* opr) override;
};
} // namespace jit
} // namespace mgb
#endif // MGB_OPENCL
#include "./utils.h"
#include <vector>
#if MGB_JIT && MGB_OPENCL
std::vector<LayoutType> get_channel_broadcast_info(
const mgb::jit::JITExecutor::Args& args) {
auto output_dim = args.outputs[0].layout.ndim;
auto& out_layout = args.outputs[0].layout;
mgb_assert(
out_layout.ndim == 5,
"code issue happened, OpenCL jit only support image now");
auto n = out_layout[0];
auto c = out_layout[2] * 4;
auto h = out_layout[1];
auto w = out_layout[3];
std::vector<LayoutType> ret;
for (size_t i = 0; i < args.inputs.size(); i++) {
if (args.inputs[i].layout.is_scalar()) {
ret.push_back(LayoutType::SCALAR);
} else {
auto& in_layout = args.inputs[i].layout;
auto in = in_layout[0];
auto ic = in_layout[2] * 4;
auto ih = in_layout[1];
auto iw = in_layout[3];
mgb_assert(
in_layout.ndim == output_dim && in == n && ic == c,
"invalid args for OpenCL jit");
if (ih == h && iw == w) {
ret.push_back(LayoutType::VEC);
} else {
ret.push_back(LayoutType::BROADCAST);
mgb_assert(ih == 1 && iw == 1, "invalid args for OpenCL jit");
}
}
}
return ret;
}
#endif
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_OPENCL
#include "megbrain/jit/compiler.h"
template <typename T, typename S>
T safe_icast(S val) {
static_assert(
std::is_integral<S>::value && std::is_integral<T>::value, "must be int");
mgb_assert(
val <= static_cast<S>(std::numeric_limits<T>::max()) &&
val >= static_cast<S>(0));
return static_cast<T>(val);
}
template <typename S>
int safe_int(S val) {
return safe_icast<int>(val);
}
enum class LayoutType {
SCALAR = 0,
BROADCAST = 1,
VEC = 2,
};
/*!
* \brief get inputs channel broadcast info
 * \param args the mgb::jit::JITExecutor::Args of the fusion opr
 * \return channel broadcast type of each input, indexed by input idx
*/
std::vector<LayoutType> get_channel_broadcast_info(
const mgb::jit::JITExecutor::Args& args);
#endif
...@@ -42,11 +42,12 @@ public: ...@@ -42,11 +42,12 @@ public:
inline ASTPtr(int imm); inline ASTPtr(int imm);
inline ASTPtr(float imm); inline ASTPtr(float imm);
inline ASTPtr(float imm, CompNode::DeviceType cn_type, bool is_half);
}; };
using ASTPtrArray = SmallVector<ASTPtr>; using ASTPtrArray = SmallVector<ASTPtr>;
//! function type for generating AST nodes //! function type for generating AST nodes
using AstGenerator = thin_function<ASTPtrArray(const ASTPtrArray&)>; using AstGenerator = thin_function<ASTPtrArray(const ASTPtrArray&, bool is_half)>;
class IntAST : public AST { class IntAST : public AST {
public: public:
...@@ -59,11 +60,20 @@ private: ...@@ -59,11 +60,20 @@ private:
class FloatAST : public AST { class FloatAST : public AST {
public: public:
FloatAST(float val) : m_val(val) {} FloatAST(
inline std::string code_gen() override { return ssprintf("float(%.12e)", m_val); } float val, CompNode::DeviceType cn_type = CompNode::DeviceType::CPU,
bool is_half = false)
: m_val(val), m_cn_type(cn_type), m_is_half(is_half) {}
inline std::string code_gen() override {
mgb_assert(!m_is_half, "code issue, only OpenCL support as half now");
return ssprintf("float(%.12e)", m_val);
}
private: private:
float m_val; float m_val;
CompNode::DeviceType m_cn_type;
bool m_is_half;
}; };
class VariableAST : public AST { class VariableAST : public AST {
...@@ -139,13 +149,20 @@ public: ...@@ -139,13 +149,20 @@ public:
class DeclFloatAST : public AST { class DeclFloatAST : public AST {
public: public:
DeclFloatAST(const ASTPtr& var) : m_var(var) {} DeclFloatAST(
const ASTPtr& var, CompNode::DeviceType cn_type = CompNode::DeviceType::CPU,
bool is_half = false)
: m_var(var), m_cn_type(cn_type), m_is_half(is_half) {}
inline std::string code_gen() override { inline std::string code_gen() override {
mgb_assert(!m_is_half, "code issue, only OpenCL support as half now");
return "float " + m_var->code_gen() + ";"; return "float " + m_var->code_gen() + ";";
} }
private: private:
ASTPtr m_var; ASTPtr m_var;
CompNode::DeviceType m_cn_type;
bool m_is_half;
}; };
class DeclIntAST : public AST { class DeclIntAST : public AST {
...@@ -205,23 +222,29 @@ ASTPtr::ASTPtr(int imm) : m_ptr(std::make_shared<IntAST>(imm)) {} ...@@ -205,23 +222,29 @@ ASTPtr::ASTPtr(int imm) : m_ptr(std::make_shared<IntAST>(imm)) {}
ASTPtr::ASTPtr(float imm) : m_ptr(std::make_shared<FloatAST>(imm)) {} ASTPtr::ASTPtr(float imm) : m_ptr(std::make_shared<FloatAST>(imm)) {}
ASTPtr::ASTPtr(float imm, CompNode::DeviceType cn_type, bool is_half)
: m_ptr(std::make_shared<FloatAST>(imm, cn_type, is_half)) {}
using ElemMode = opr::Elemwise::Mode; using ElemMode = opr::Elemwise::Mode;
using ElemGeneratorMap = ThinHashMap<ElemMode, AstGenerator>; using ElemGeneratorMap = ThinHashMap<ElemMode, AstGenerator>;
//! mapping from elemwise mode to ast node generator //! mapping from elemwise mode to ast node generator
const ElemGeneratorMap& elem_opr_generator(); const ElemGeneratorMap& elem_opr_generator(CompNode::DeviceType type);
static inline bool check_elem_mode(ElemMode mode) { static inline bool check_elem_mode(ElemMode mode, CompNode::DeviceType type) {
return elem_opr_generator().count(mode); return elem_opr_generator(type).count(mode);
} }
/*! /*!
* \brief Generate a AST node from the opr and the given ast inputs * \brief Generate a AST node from the opr and the given ast inputs
* \param opr the opr * \param opr the opr
* \param inputs the AST inputs of the ASTs to be generate * \param inputs the AST inputs of the ASTs to be generate
* \param device_type jit backend cn device type
* \return AST nodes corresponding to opr value outputs * \return AST nodes corresponding to opr value outputs
*/ */
ASTPtrArray opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs); ASTPtrArray opr2AST(
cg::OperatorNodeBase* opr, const ASTPtrArray& inputs,
CompNode::DeviceType device_type);
} // namespace ast_c } // namespace ast_c
} // namespace jit } // namespace jit
......
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
const InternalGraphPtr& internal_graph, const VarNodeArray& inputs, const InternalGraphPtr& internal_graph, const VarNodeArray& inputs,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);
static SymbolVar make( MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const InternalGraphPtr& internal_graph, const VarNodeArray& inputs, const InternalGraphPtr& internal_graph, const VarNodeArray& inputs,
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
......
...@@ -82,10 +82,10 @@ class InternalGraphGenerator { ...@@ -82,10 +82,10 @@ class InternalGraphGenerator {
void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr); void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr);
public: public:
explicit InternalGraphGenerator(cg::OperatorNodeBase* opr); MGE_WIN_DECLSPEC_FUC explicit InternalGraphGenerator(cg::OperatorNodeBase* opr);
//! generate the graph; this method can be called multiple times //! generate the graph; this method can be called multiple times
InternalGraphPtr generate(); MGE_WIN_DECLSPEC_FUC InternalGraphPtr generate();
/*! /*!
* \brief needed input vars in the original (i.e. outer) graph * \brief needed input vars in the original (i.e. outer) graph
...@@ -120,7 +120,7 @@ public: ...@@ -120,7 +120,7 @@ public:
size_t get_cnt_input_if_add(cg::OperatorNodeBase* opr) const; size_t get_cnt_input_if_add(cg::OperatorNodeBase* opr) const;
//! add an operator into this graph; its outputs must have been added //! add an operator into this graph; its outputs must have been added
void add_opr(cg::OperatorNodeBase* opr); MGE_WIN_DECLSPEC_FUC void add_opr(cg::OperatorNodeBase* opr);
//! output var in the outer graph (i.e. the root node) //! output var in the outer graph (i.e. the root node)
VarNode* output() const { return m_output; } VarNode* output() const { return m_output; }
......