diff --git a/src/gopt/impl/reformat_emitter.cpp b/src/gopt/impl/reformat_emitter.cpp
index 8c9bacfaa60669cfa946ba2bbfc9859baac170a2..9b5b37d84a15a6731fef7e9ca608f9e9a12855ee 100644
--- a/src/gopt/impl/reformat_emitter.cpp
+++ b/src/gopt/impl/reformat_emitter.cpp
@@ -19,6 +19,7 @@ using namespace gopt;
 using Dimension = megdnn::Dimension;
 using NamedTensorShape = megdnn::NamedTensorShape;
 
+// =================== ModifyShapeMixin ====================
 ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
 ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
         const Pattern& pattern) const {
     auto src = m_src;
-    auto checker = [src, pattern](VarNode* var) {
+    auto checker = [src, pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() >= 1);
+        const auto& var = input.front();
         const auto& shp = var->shape();
         if (shp.ndim != src.ndim)
             return false;
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
     return checker;
 }
 
-ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+// =================== MakeShapeEmitter ====================
+MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const {
     auto pattern = mixin_analyze();
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of inputs of MakeShapeEmitter should be 1 (got: %zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         auto shp = opr::GetVarShape::make(sym_var);
         auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
         auto sub = [&shp, &cv](int ax) {
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
             }
         }
         auto tshp = opr::Concat::make(axs, 0);
-        auto ovar = opr::Reshape::make(sym_var, tshp);
+        return tshp.node();
+    };
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
+}
+
+// =================== ReshapeEmitter ====================
+ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+    auto pattern = mixin_analyze();
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 2,
+                   "number of inputs of Reshape should be 2 (got: %zu)",
+                   input.size());
+        auto ovar = opr::Reshape::make(input[0], input[1]);
         return ovar.node();
     };
     auto checker = mixin_emit_checker(pattern);
     return std::make_tuple(builder, checker);
 }
 
+// =================== DimshuffleEmitter ====================
 DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
     auto&& pattern = m_pattern;
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of inputs of Dimshuffle should be 1 (got: %zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         return opr::Dimshuffle::make(sym_var, pattern).node();
     };
-    auto checker = [pattern](VarNode* var) {
-        return var->shape().ndim == pattern.size();
+    auto checker = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of inputs of Dimshuffle should be 1 (got: %zu)",
+                   input.size());
+        return input.front()->shape().ndim == pattern.size();
     };
     return std::make_tuple(builder, checker);
 }
 
+// =================== ReformatEmitter ====================
 ReformatEmitter::EmitResult ReformatEmitter::emit() const {
-    auto ops = analyze();
-    auto builder = [ops](VarNode* var) {
-        VarNode* ovar = var;
-        for (const auto& i : ops) {
-            ovar = i(ovar);
+    auto builders = analyze();
+    auto builder = [builders](const VarNodeArray& input) {
+        VarNode *var, *ovar;
+        var = ovar = input.front();
+        if (builders.make_shape1) {
+            auto shp1 = builders.make_shape1({var});
+            ovar = builders.reshape1({ovar, shp1});
+        }
+        ovar = builders.dimshuffle({ovar});
+        if (builders.make_shape2) {
+            auto shp2 = builders.make_shape2({var});
+            ovar = builders.reshape2({ovar, shp2});
         }
         return ovar;
     };
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
     return std::make_tuple(builder, checker);
 }
 
-SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
+ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     struct Dim {
         Dimension dim;
         int index;
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
         i1[i] = src_dims[src_perm[i]].dim;
         i2[i] = src_dims[src_perm[permute[i]]].dim;
     }
-    SmallVector<Builder> ops;
-    if (!m_src.eq_shape(i1))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
-    ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
-    if (!m_dest.eq_shape(i2))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
-    return ops;
+    UnderlyingBuilders builders;
+    if (!m_src.eq_shape(i1)) {
+        builders.make_shape1 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
+        builders.reshape1 =
+                std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+    }
+    builders.dimshuffle =
+            std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
+    if (!m_dest.eq_shape(i2)) {
+        builders.make_shape2 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
+        builders.reshape2 =
+                std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
+    }
+    return builders;
 }
 
 // vim: syntax=cpp.doxygen
diff --git a/src/gopt/include/megbrain/gopt/reformat_emitter.h b/src/gopt/include/megbrain/gopt/reformat_emitter.h
index 9d83cf85083dfe1945f87d623ebefdcf1d6f3fd1..bd62c2153402b02c7da5e53c94cf7626d9f361a2 100644
--- a/src/gopt/include/megbrain/gopt/reformat_emitter.h
+++ b/src/gopt/include/megbrain/gopt/reformat_emitter.h
@@ -20,8 +20,8 @@ namespace gopt {
 
 class Emitter {
 public:
-    using Builder = thin_function<VarNode*(VarNode*)>;
-    using Checker = thin_function<bool(VarNode*)>;
+    using Builder = thin_function<VarNode*(const VarNodeArray&)>;
+    using Checker = thin_function<bool(const VarNodeArray&)>;
     using EmitResult = std::tuple<Builder, Checker>;
     virtual ~Emitter() = default;
     virtual EmitResult emit() const = 0;
@@ -39,6 +39,14 @@ protected:
     megdnn::NamedTensorShape m_src, m_dest;
 };
 
+class MakeShapeEmitter final : public Emitter, ModifyShapeMixin {
+public:
+    MakeShapeEmitter(const megdnn::NamedTensorShape& src,
+                     const megdnn::NamedTensorShape& dest)
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
+};
+
 class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
 public:
     ReshapeEmitter(const megdnn::NamedTensorShape& src,
@@ -64,7 +72,10 @@ public:
     EmitResult emit() const override;
 
 private:
-    SmallVector<Builder> analyze() const;
+    struct UnderlyingBuilders {
+        Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle;
+    };
+    UnderlyingBuilders analyze() const;
 };
 } // namespace gopt
 } // namespace mgb
diff --git a/src/gopt/test/reformat_emitter.cpp b/src/gopt/test/reformat_emitter.cpp
index ab8ae1d0868b93325d8d1d6254ee8f53f0095301..81901cef1eaccb9a9e5737d9c94825a3e6006221 100644
--- a/src/gopt/test/reformat_emitter.cpp
+++ b/src/gopt/test/reformat_emitter.cpp
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
     constexpr size_t N = 12, C = 64, H = 7, W = 7;
     HostTensorGenerator<> gen;
     using NamedTensorShape = megdnn::NamedTensorShape;
-    auto dest = NamedTensorShape::make_named_tensor_shape(
-            NamedTensorShape::Format::NCHW4);
     auto src = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW32);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
     auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
     auto reformat = std::get<0>(tuple);
     auto checker = std::get<1>(tuple);
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 32, H, W, 32});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y1 = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_shapeof = 0;
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::GetVarShape>())
+            nr_shapeof++;
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_shapeof, 1);
+    ASSERT_EQ(nr_reshape, 2);
     auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
     HostTensorND t1, t2;
     auto func1 = graph->compile({make_callback_copy(y1, t1)});
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 64, H, W, 64});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y = SymbolVar(reformat({x.node()}));
     HostTensorND t;
     auto func = graph->compile({make_callback_copy(y, t)});
     func->execute();
 }
+
+TEST(TestReformatEmitter, EliminateRedundantReshape) {
+    constexpr size_t N = 16, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NHWC);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    auto nchw_to_nhwc = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1});
+        return y.node();
+    };
+
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C, H, W});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_reshape, 0);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+
+TEST(TestReformatEmitter, Nchw4ToNchw) {
+    constexpr size_t N = 12, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    auto nchw4_to_nchw = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp);
+        return y1.node();
+    };
+
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C / 4, H, W, 4});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    SmallVector<VarNode*> reshapes;
+    VarNode* dimshuffle = nullptr;
+    cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>()) {
+            reshapes.push_back(o->output(0));
+        }
+        if (o->same_type<opr::Dimshuffle>())
+            dimshuffle = o->output(0);
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(reshapes.size(), 1);
+    {
+        gopt::SubGraph graph({y1});
+        gopt::UniqReaderCheck check(graph);
+        EXPECT_TRUE(check(reshapes[0]));
+        EXPECT_TRUE(dimshuffle);
+    }
+    auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
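
Usage note (not part of the patch): a minimal sketch of the calling convention
this change introduces, mirroring the tests above. Builders and checkers now
take a VarNodeArray instead of a single VarNode*, so even single-input
emitters are invoked with a braced list. The helper name below
(make_nchw32_to_nchw4) and the exact include set are illustrative
assumptions, not code from this patch.

    #include "megbrain/gopt/reformat_emitter.h"

    using namespace mgb;

    // Build a reformat subgraph converting an NCHW32-layout var to NCHW4.
    VarNode* make_nchw32_to_nchw4(VarNode* x) {
        using NamedTensorShape = megdnn::NamedTensorShape;
        auto src = NamedTensorShape::make_named_tensor_shape(
                NamedTensorShape::Format::NCHW32);
        auto dest = NamedTensorShape::make_named_tensor_shape(
                NamedTensorShape::Format::NCHW4);
        auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
        auto builder = std::get<0>(tuple);  // Builder: VarNode*(const VarNodeArray&)
        auto checker = std::get<1>(tuple);  // Checker: bool(const VarNodeArray&)
        // The checker validates the input shape against the source layout;
        // the builder then emits the MakeShape/Reshape/Dimshuffle oprs into
        // the graph owning x.
        mgb_assert(checker({x}), "input is not laid out as NCHW32");
        return builder({x});
    }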