Commit e715423f authored by Megvii Engine Team

feat(src/gopt): add an optimization pass on ARM that fuses TypeCvt and Elemwise into ElemwiseMultiType

GitOrigin-RevId: e6bcbbf91bd24460b2ba2bf7dff3cd3ba13ca7e5
Parent f6d99094
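
The pass rewrites a standalone TypeCvt feeding an Elemwise into one fused ElemwiseMultiType opr. A minimal sketch of the transform, built from the opr calls that appear in the diff below (the variable names x_i16, s_f32 and b_f32 are illustrative):

    // before: TypeCvt(Int16 -> Float32) followed by Elemwise(FUSE_MUL_ADD3)
    auto x_f32 = opr::TypeCvt::make(x_i16, dtype::Float32());
    auto y = opr::Elemwise::make(
            {x_f32, s_f32, b_f32}, opr::Elemwise::Param::Mode::FUSE_MUL_ADD3);

    // after: a single ElemwiseMultiType that reads the Int16 input directly
    auto y_fused = opr::ElemwiseMultiType::make(
            {x_i16, s_f32, b_f32},
            {opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32},
            OperatorNodeConfig{dtype::Float32()});

The same rewrite covers Uint8 sources (FUSE_MUL_ADD3_UINT8xF32xF32xF32) and Elemwise MUL (MUL_INT16xF32xF32); all three variants are exercised by the new tests.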
@@ -644,6 +644,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
add_pass<RemoveRedundantTypeCvtPass>();
add_pass<RemoveRedundantCopyPass>();
//! Only arm_common implements the fused TypeCvt + Elemwise optimized kernels
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
add_pass<FuseTypecvtElemwisePass>();
#endif
#if MGB_JIT
using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
int jit_opt_level = 0;
@@ -691,6 +696,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
// remove shape hint after inference optimization
add_pass<RemoveShapeHintPass>();
}
return *this;
}
......
@@ -2187,4 +2187,144 @@ void ParamMergePass::apply(OptState& opt_state) const {
MIDOUT_E
}
/* ==================== FuseTypecvtElemwisePass ================= */
const char* FuseTypecvtElemwisePass::name() const {
return mgb_cstr_log("Fuse typecvt elemwise pass");
}
void FuseTypecvtElemwisePass::apply(OptState& opt) const {
MIDOUT_B("FuseTypecvtElemwisePass::apply")
opt.set_var_replace_check_flag(
VarReplaceCheckFlag::CHECK_DTYPE | VarReplaceCheckFlag::CHECK_SHAPE);
auto rewriter = opt.graph().make_rewriter();
auto uniq_reader_check = UniqReaderCheck{opt.graph()};
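//! match Elemwise(FUSE_MUL_ADD3) with float32 input(0)/output whose first
//! input is produced by TypeCvt(Int16/Uint8 -> Float32), and replace the
//! pair with a single fused ElemwiseMultiType opr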
auto try_typecvt_elemwise_fma_i16xf32xf32xf32 = [&rewriter, &uniq_reader_check](
OperatorNodeBase* opr) {
// check elemwise
auto elemwise = try_cast_as_op<opr::Elemwise>(opr);
if (elemwise == nullptr)
return false;
if (elemwise->param().mode != opr::Elemwise::Mode::FUSE_MUL_ADD3)
return false;
bool is_elem_src_f32 =
elemwise->input(0)->dtype().enumv() == DTypeEnum::Float32;
bool is_elem_dst_f32 =
elemwise->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (!(is_elem_src_f32 && is_elem_dst_f32))
return false;
if (!uniq_reader_check(elemwise->input(0)))
return false;
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(elemwise->input(0)->owner_opr());
if (typecvt == nullptr)
return false;
bool is_typecvt_src_i16 =
typecvt->input(0)->dtype().enumv() == DTypeEnum::Int16;
bool is_typecvt_src_u8 = typecvt->input(0)->dtype().enumv() == DTypeEnum::Uint8;
bool is_typecvt_dst_f32 =
typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (!((is_typecvt_src_i16 || is_typecvt_src_u8) && is_typecvt_dst_f32))
return false;
SymbolVar new_elem;
auto src0 = rewriter.get_var(typecvt->input(0)),
src1 = rewriter.get_var(elemwise->input(1)),
src2 = rewriter.get_var(elemwise->input(2));
if (is_typecvt_src_i16) {
new_elem = opr::ElemwiseMultiType::make(
{src0, src1, src2},
{opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32},
OperatorNodeConfig{dtype::Float32()});
} else {
new_elem = opr::ElemwiseMultiType::make(
{src0, src1, src2},
{opr::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32},
OperatorNodeConfig{dtype::Float32()});
}
rewriter.replace_var(
opr->output(0), new_elem.node(),
mgb_cstr_log("replace typecvt + elemwise(FUSE_MUL_ADD3)"
"to ElemwiseMultiType(FUSE_MUL_ADD3_INTXxF32xF32xF32)"));
return true;
};
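//! same idea for Elemwise(MUL): the TypeCvt(Int16 -> Float32) may feed
//! either operand, so locate it on src0 or src1 before fusing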
auto try_typecvt_elemwise_mul_i16xf32xf32 = [&rewriter, &uniq_reader_check](
OperatorNodeBase* opr) {
// check elemwise
auto elemwise = try_cast_as_op<opr::Elemwise>(opr);
if (elemwise == nullptr)
return false;
if (elemwise->param().mode != opr::Elemwise::Mode::MUL)
return false;
bool is_elem_src_f32 =
elemwise->input(0)->dtype().enumv() == DTypeEnum::Float32;
bool is_elem_dst_f32 =
elemwise->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (!(is_elem_src_f32 && is_elem_dst_f32))
return false;
// maybe src0 or src1
if (!(try_cast_as_op<opr::TypeCvt>(elemwise->input(0)->owner_opr()) ||
try_cast_as_op<opr::TypeCvt>(elemwise->input(1)->owner_opr())))
return false;
int typecvt_src_idx = (try_cast_as_op<opr::TypeCvt>(
elemwise->input(0)->owner_opr()) != nullptr)
? 0
: 1;
int other_src_idx = (typecvt_src_idx == 0) ? 1 : 0;
if (!uniq_reader_check(elemwise->input(typecvt_src_idx)))
return false;
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(
elemwise->input(typecvt_src_idx)->owner_opr());
bool is_typecvt_src_i16 =
typecvt->input(0)->dtype().enumv() == DTypeEnum::Int16;
bool is_typecvt_dst_f32 =
typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (!(is_typecvt_src_i16 && is_typecvt_dst_f32))
return false;
SymbolVar new_elem;
auto src0 = rewriter.get_var(typecvt->input(0)),
src1 = rewriter.get_var(elemwise->input(other_src_idx));
new_elem = opr::ElemwiseMultiType::make(
{src0, src1}, {opr::ElemwiseMultiType::Mode::MUL_INT16xF32xF32},
OperatorNodeConfig{dtype::Float32()});
rewriter.replace_var(
opr->output(0), new_elem.node(),
mgb_cstr_log("replace typecvt + elemwise(MUL)"
"to ElemwiseMultiType(MUL_INT16xF32xF32)"));
return true;
};
auto on_opr = [&try_typecvt_elemwise_fma_i16xf32xf32xf32,
&try_typecvt_elemwise_mul_i16xf32xf32,
&rewriter](OperatorNodeBase* opr) {
if (!try_typecvt_elemwise_fma_i16xf32xf32xf32(opr) &&
!try_typecvt_elemwise_mul_i16xf32xf32(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -200,6 +200,15 @@ public:
void apply(OptState& opt_state) const override;
};
/*!
* \brief fuse TypeCvt and Elemwise into a single ElemwiseMultiType opr
*/
class FuseTypecvtElemwisePass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};
/*!
* \brief tensor format converter to accelerate inference speed on Nvidia
* platform
......
@@ -1780,6 +1780,138 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
}
}
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
TEST(TestGoptInference, FuseTypeCvtAndElemwiseCase0) {
HostTensorGenerator<dtype::Int16, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
size_t n = 1;
size_t c = 128;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, h, w, c}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_nchw = opr::Dimshuffle::make(x, {0, 3, 1, 2}, 4, cn);
auto x_f32 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto s = mkcvar("s", {1, c, 1, 1});
auto b = mkcvar("b", {1, c, 1, 1});
auto result = opr::Elemwise::make(
{x_f32, s, b}, opr::Elemwise::Param::Mode::FUSE_MUL_ADD3);
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::ElemwiseMultiType>());
ASSERT_EQ(
opr::ElemwiseMultiType::Param::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32,
find_opr<opr::ElemwiseMultiType>(y_opt).param().mode);
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
graph->options().graph_opt_level = 2;
auto func_opt = graph->compile({make_callback_copy(y, host_y_opt)});
func_opt->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
TEST(TestGoptInference, FuseTypeCvtAndElemwiseCase1) {
HostTensorGenerator<dtype::Int16, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
size_t n = 1;
size_t c = 128;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, h, w, c}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_nchw = opr::Dimshuffle::make(x, {0, 3, 1, 2}, 4, cn);
auto x_f32 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto s = mkcvar("s", {1, c, 1, 1});
auto result = opr::Elemwise::make({x_f32, s}, opr::Elemwise::Param::Mode::MUL);
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::ElemwiseMultiType>());
ASSERT_EQ(
opr::ElemwiseMultiType::Param::Mode::MUL_INT16xF32xF32,
find_opr<opr::ElemwiseMultiType>(y_opt).param().mode);
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
graph->options().graph_opt_level = 2;
auto func_opt = graph->compile({make_callback_copy(y, host_y_opt)});
func_opt->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
TEST(TestGoptInference, FuseTypeCvtAndElemwiseCase2) {
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
size_t n = 1;
size_t c = 128;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, h, w, c}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_nchw = opr::Dimshuffle::make(x, {0, 3, 1, 2}, 4, cn);
auto x_f32 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto s = mkcvar("s", {1, c, 1, 1});
auto b = mkcvar("b", {1, c, 1, 1});
auto result = opr::Elemwise::make(
{x_f32, s, b}, opr::Elemwise::Param::Mode::FUSE_MUL_ADD3);
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::ElemwiseMultiType>());
ASSERT_EQ(
opr::ElemwiseMultiType::Param::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32,
find_opr<opr::ElemwiseMultiType>(y_opt).param().mode);
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
graph->options().graph_opt_level = 2;
auto func_opt = graph->compile({make_callback_copy(y, host_y_opt)});
func_opt->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
#endif
TEST(TestGoptInference, ParamMerge) {
auto cns = load_multiple_xpus(2);
HostTensorGenerator<> gen;
......
@@ -77,4 +77,12 @@ void ElemwiseMultiType::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
void ElemwiseMultiType::add_input_layout_constraint() {
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) && !MGB_OPENCL && !MGB_CUDA
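//! presumably the fused arm_common kernels only handle dense tensors, so
//! require contiguous input layouts on ARM builds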
for (auto i : input()) {
i->add_layout_constraint_contiguous();
}
#endif
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -49,6 +49,8 @@ private:
void init_output_dtype() override;
void record_execute_deps(ExecDependencyArray& deps) override;
void add_input_layout_constraint() override;
};
//! deprecated; TODO: remove in megbrain 8
......