diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index 997a9f6ec026a405b58ce1f1fe8f67a241f79016..7467cd951f52eb2e1acc211cfee675abdf370e0f 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -644,6 +644,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
     add_pass();
     add_pass();
 
+    //! only arm_common implements the fused TypeCvt + Elemwise optimized kernels
+#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
+    add_pass<FuseTypecvtElemwisePass>();
+#endif
+
 #if MGB_JIT
     using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
     int jit_opt_level = 0;
@@ -691,6 +696,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
         // remove shape hint after inference optimization
        add_pass<RemoveShapeHintPass>();
    }
+
     return *this;
 }
diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index b22a068a8e0b9911000e8ae33fd117d1b347c5b5..7f3c04ac73303f51e3e128d652238e041232bc9a 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -2187,4 +2187,144 @@ void ParamMergePass::apply(OptState& opt_state) const {
     MIDOUT_E
 }
 
+/* ==================== FuseTypecvtElemwisePass ================= */
+const char* FuseTypecvtElemwisePass::name() const {
+    return mgb_cstr_log("Fuse typecvt elemwise pass");
+}
+
+void FuseTypecvtElemwisePass::apply(OptState& opt) const {
+    MIDOUT_B("FuseTypecvtElemwisePass::apply")
+    opt.set_var_replace_check_flag(
+            VarReplaceCheckFlag::CHECK_DTYPE | VarReplaceCheckFlag::CHECK_SHAPE);
+    auto rewriter = opt.graph().make_rewriter();
+    auto uniq_reader_check = UniqReaderCheck{opt.graph()};
+
+    auto try_typecvt_elemwise_fma_i16xf32xf32xf32 = [&rewriter, &uniq_reader_check](
+                                                            OperatorNodeBase* opr) {
+        // check elemwise
+        auto elemwise = try_cast_as_op<opr::Elemwise>(opr);
+        if (elemwise == nullptr)
+            return false;
+        if (elemwise->param().mode != opr::Elemwise::Mode::FUSE_MUL_ADD3)
+            return false;
+
+        bool is_elem_src_f32 =
+                elemwise->input(0)->dtype().enumv() == DTypeEnum::Float32;
+        bool is_elem_dst_f32 =
+                elemwise->output(0)->dtype().enumv() == DTypeEnum::Float32;
+
+        if (!(is_elem_src_f32 && is_elem_dst_f32))
+            return false;
+        if (!uniq_reader_check(elemwise->input(0)))
+            return false;
+
+        // check typecvt
+        auto typecvt = try_cast_as_op<opr::TypeCvt>(elemwise->input(0)->owner_opr());
+        if (typecvt == nullptr)
+            return false;
+
+        bool is_typecvt_src_i16 =
+                typecvt->input(0)->dtype().enumv() == DTypeEnum::Int16;
+        bool is_typecvt_src_u8 =
+                typecvt->input(0)->dtype().enumv() == DTypeEnum::Uint8;
+        bool is_typecvt_dst_f32 =
+                typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
+
+        if (!((is_typecvt_src_i16 || is_typecvt_src_u8) && is_typecvt_dst_f32))
+            return false;
+
+        SymbolVar new_elem;
+        auto src0 = rewriter.get_var(typecvt->input(0)),
+             src1 = rewriter.get_var(elemwise->input(1)),
+             src2 = rewriter.get_var(elemwise->input(2));
+        if (is_typecvt_src_i16) {
+            new_elem = opr::ElemwiseMultiType::make(
+                    {src0, src1, src2},
+                    {opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32},
+                    OperatorNodeConfig{dtype::Float32()});
+        } else {
+            new_elem = opr::ElemwiseMultiType::make(
+                    {src0, src1, src2},
+                    {opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32},
+                    OperatorNodeConfig{dtype::Float32()});
+        }
+
+        rewriter.replace_var(
+                opr->output(0), new_elem.node(),
+                mgb_cstr_log("replace typecvt + elemwise(FUSE_MUL_ADD3) "
+                             "to ElemwiseMultiType(FUSE_MUL_ADD3_INTXxF32xF32xF32)"));
+
+        return true;
+    };
+
+    auto try_typecvt_elemwise_mul_i16xf32xf32 = [&rewriter, &uniq_reader_check](
+                                                        OperatorNodeBase* opr) {
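+        // Matches Elemwise(MUL)(TypeCvt(i16 -> f32), f32) with the TypeCvt in
+        // either operand position, and rewrites it to
+        // ElemwiseMultiType(MUL_INT16xF32xF32), folding the conversion into
+        // the multiply kernel instead of materializing an f32 intermediate.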
+        // check elemwise
+        auto elemwise = try_cast_as_op<opr::Elemwise>(opr);
+        if (elemwise == nullptr)
+            return false;
+        if (elemwise->param().mode != opr::Elemwise::Mode::MUL)
+            return false;
+
+        bool is_elem_src_f32 =
+                elemwise->input(0)->dtype().enumv() == DTypeEnum::Float32;
+        bool is_elem_dst_f32 =
+                elemwise->output(0)->dtype().enumv() == DTypeEnum::Float32;
+
+        if (!(is_elem_src_f32 && is_elem_dst_f32))
+            return false;
+        // the TypeCvt may feed either src0 or src1
+        if (!(try_cast_as_op<opr::TypeCvt>(elemwise->input(0)->owner_opr()) ||
+              try_cast_as_op<opr::TypeCvt>(elemwise->input(1)->owner_opr())))
+            return false;
+
+        int typecvt_src_idx = (try_cast_as_op<opr::TypeCvt>(
+                                       elemwise->input(0)->owner_opr()) != nullptr)
+                                    ? 0
+                                    : 1;
+
+        int other_src_idx = (typecvt_src_idx == 0) ? 1 : 0;
+
+        if (!uniq_reader_check(elemwise->input(typecvt_src_idx)))
+            return false;
+
+        // check typecvt
+        auto typecvt = try_cast_as_op<opr::TypeCvt>(
+                elemwise->input(typecvt_src_idx)->owner_opr());
+
+        bool is_typecvt_src_i16 =
+                typecvt->input(0)->dtype().enumv() == DTypeEnum::Int16;
+        bool is_typecvt_dst_f32 =
+                typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
+
+        if (!(is_typecvt_src_i16 && is_typecvt_dst_f32))
+            return false;
+
+        SymbolVar new_elem;
+        auto src0 = rewriter.get_var(typecvt->input(0)),
+             src1 = rewriter.get_var(elemwise->input(other_src_idx));
+        new_elem = opr::ElemwiseMultiType::make(
+                {src0, src1}, {opr::ElemwiseMultiType::Mode::MUL_INT16xF32xF32},
+                OperatorNodeConfig{dtype::Float32()});
+
+        rewriter.replace_var(
+                opr->output(0), new_elem.node(),
+                mgb_cstr_log("replace typecvt + elemwise(MUL) "
+                             "to ElemwiseMultiType(MUL_INT16xF32xF32)"));
+
+        return true;
+    };
+
+    auto on_opr = [&try_typecvt_elemwise_fma_i16xf32xf32xf32,
+                   &try_typecvt_elemwise_mul_i16xf32xf32,
+                   &rewriter](OperatorNodeBase* opr) {
+        if (!try_typecvt_elemwise_fma_i16xf32xf32xf32(opr) &&
+            !try_typecvt_elemwise_mul_i16xf32xf32(opr)) {
+            rewriter.auto_replace_outputs(opr);
+        }
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+    MIDOUT_E
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index 2ee23488f47e66504b892016decfa6ec91ec187b..7c8c1aef288a656a69f04cfda03cb4ee03c88658 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -200,6 +200,15 @@ public:
     void apply(OptState& opt_state) const override;
 };
 
+/*!
+ * \brief Fuse typecvt elemwise
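+ *
+ * Rewrites TypeCvt (Int16/Uint8 -> Float32) followed by an Elemwise
+ * FUSE_MUL_ADD3 or MUL into a single ElemwiseMultiType operator; only
+ * registered on arm_common builds, which provide the fused kernels.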
+ */
+class FuseTypecvtElemwisePass final : public Pass {
+public:
+    const char* name() const override;
+    void apply(OptState& opt) const override;
+};
+
 /*!
  * \brief tensor format converter to accelerate inference speed on Nvidia
  * platform
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 161d0a20fd0843e9615da21a36d978cd4bf662e4..2f783530824e5decbee8921341ac259145e46548 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -1780,6 +1780,138 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
     }
 }
 
+#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
+TEST(TestGoptInference, FuseTypeCvtAndElemwiseCase0) {
+    HostTensorGenerator<dtype::Int16> gen(0, 255);
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    size_t n = 1;
+    size_t c = 128;
+    size_t h = 16;
+    size_t w = 16;
+    auto host_x1 = gen({n, h, w, c}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+
+    auto x_nchw = opr::Dimshuffle::make(x, {0, 3, 1, 2}, 4, cn);
+    auto x_f32 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+    auto s = mkcvar("s", {1, c, 1, 1});
+    auto b = mkcvar("b", {1, c, 1, 1});
+
+    auto result = opr::Elemwise::make(
+            {x_f32, s, b}, opr::Elemwise::Param::Mode::FUSE_MUL_ADD3);
+
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::ElemwiseMultiType>());
+
+    ASSERT_EQ(
+            opr::ElemwiseMultiType::Param::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32,
+            find_opr<opr::ElemwiseMultiType>(y_opt).param().mode);
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+    func->execute();
+    graph->options().graph_opt_level = 2;
+    auto func_opt = graph->compile({make_callback_copy(y, host_y_opt)});
+    func_opt->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+}
+
+TEST(TestGoptInference, FuseTypeCvtAndElemwiseCase1) {
+    HostTensorGenerator<dtype::Int16> gen(0, 255);
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    size_t n = 1;
+    size_t c = 128;
+    size_t h = 16;
+    size_t w = 16;
+    auto host_x1 = gen({n, h, w, c}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+
+    auto x_nchw = opr::Dimshuffle::make(x, {0, 3, 1, 2}, 4, cn);
+    auto x_f32 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+    auto s = mkcvar("s", {1, c, 1, 1});
+
+    auto result = opr::Elemwise::make({x_f32, s}, opr::Elemwise::Param::Mode::MUL);
+
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::ElemwiseMultiType>());
+
+    ASSERT_EQ(
+            opr::ElemwiseMultiType::Param::Mode::MUL_INT16xF32xF32,
+            find_opr<opr::ElemwiseMultiType>(y_opt).param().mode);
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+    func->execute();
+    graph->options().graph_opt_level = 2;
+    auto func_opt = graph->compile({make_callback_copy(y, host_y_opt)});
+    func_opt->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+}
+
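+// same FUSE_MUL_ADD3 pattern as Case0, but with a Uint8 source tensor to
+// cover the FUSE_MUL_ADD3_UINT8xF32xF32xF32 branch of the pass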
+TEST(TestGoptInference, FuseTypeCvtAndElemwiseCase2) {
+    HostTensorGenerator<dtype::Uint8> gen(0, 255);
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    size_t n = 1;
+    size_t c = 128;
+    size_t h = 16;
+    size_t w = 16;
+    auto host_x1 = gen({n, h, w, c}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+
+    auto x_nchw = opr::Dimshuffle::make(x, {0, 3, 1, 2}, 4, cn);
+    auto x_f32 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+    auto s = mkcvar("s", {1, c, 1, 1});
+    auto b = mkcvar("b", {1, c, 1, 1});
+
+    auto result = opr::Elemwise::make(
+            {x_f32, s, b}, opr::Elemwise::Param::Mode::FUSE_MUL_ADD3);
+
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::ElemwiseMultiType>());
+
+    ASSERT_EQ(
+            opr::ElemwiseMultiType::Param::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32,
+            find_opr<opr::ElemwiseMultiType>(y_opt).param().mode);
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+    func->execute();
+    graph->options().graph_opt_level = 2;
+    auto func_opt = graph->compile({make_callback_copy(y, host_y_opt)});
+    func_opt->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+}
+#endif
+
 TEST(TestGoptInference, ParamMerge) {
     auto cns = load_multiple_xpus(2);
     HostTensorGenerator<> gen;
diff --git a/src/opr/impl/nn_int.cpp b/src/opr/impl/nn_int.cpp
index 60fd0b66ff27098df96d9da50903ec74be11625f..e15149fbebd34501289b1e5021ef2bbd5e61ddf4 100644
--- a/src/opr/impl/nn_int.cpp
+++ b/src/opr/impl/nn_int.cpp
@@ -77,4 +77,12 @@ void ElemwiseMultiType::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
 
+void ElemwiseMultiType::add_input_layout_constraint() {
+#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
+    for (auto i : input()) {
+        i->add_layout_constraint_contiguous();
+    }
+#endif
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/include/megbrain/opr/nn_int.h b/src/opr/include/megbrain/opr/nn_int.h
index 52fa37dc9f250e2b3aeed52aa1d46f718a31f7be..c8f8f2cc061d4949f93020f1455089fdc89f5985 100644
--- a/src/opr/include/megbrain/opr/nn_int.h
+++ b/src/opr/include/megbrain/opr/nn_int.h
@@ -49,6 +49,8 @@ private:
     void init_output_dtype() override;
     void record_execute_deps(ExecDependencyArray& deps) override;
+
+    void add_input_layout_constraint() override;
 };
 
 //! deprecated; TODO: remove in megbrain 8