Commit e8a5932d authored by Megvii Engine Team

perf(mgb/gopt): optimize impl of reformat builders

GitOrigin-RevId: 844b7e8d393290a6235e70d8455ea2d6d0e124cd
Parent 58b8b145
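Summary note: this commit changes `Emitter::Builder`/`Emitter::Checker` to take a `const VarNodeArray&` instead of a single `VarNode*`, and splits the target-shape computation out of `ReshapeEmitter` into a new `MakeShapeEmitter`, so `ReformatEmitter` can skip redundant reshape steps. A minimal usage sketch of the emitted pair, mirroring the tests below (not part of the diff; `src`, `dest`, and `x` are assumed to be set up as in `TestReformatEmitter.Basic`):

```cpp
// Sketch only: emit() returns a std::tuple<Builder, Checker>.
auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
auto reformat = std::get<0>(tuple);  // Builder: VarNodeArray -> VarNode*
auto checker = std::get<1>(tuple);   // Checker: VarNodeArray -> bool
if (checker({x.node()})) {
    // Emits reshape -> dimshuffle -> reshape, eliding redundant steps.
    VarNode* y = reformat({x.node()});
}
```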
@@ -19,6 +19,7 @@ using namespace gopt;
 using Dimension = megdnn::Dimension;
 using NamedTensorShape = megdnn::NamedTensorShape;
+// =================== ModifyShapeMixin ====================*/
 ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
 ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
         const Pattern& pattern) const {
     auto src = m_src;
-    auto checker = [src, pattern](VarNode* var) {
+    auto checker = [src, pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() >= 1);
+        const auto& var = input.front();
         const auto& shp = var->shape();
         if (shp.ndim != src.ndim)
             return false;
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
     return checker;
 }
-ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+// =================== MakeShapeEmitter ====================*/
+MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const {
     auto pattern = mixin_analyze();
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of MakeShapeBuilder should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         auto shp = opr::GetVarShape::make(sym_var);
         auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
         auto sub = [&shp, &cv](int ax) {
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
             }
         }
         auto tshp = opr::Concat::make(axs, 0);
-        auto ovar = opr::Reshape::make(sym_var, tshp);
+        return tshp.node();
+    };
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
+}
+// =================== ReshapeEmitter ====================*/
+ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+    auto pattern = mixin_analyze();
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 2,
+                   "number of input of Reshape should be 2(got:%zu)",
+                   input.size());
+        auto ovar = opr::Reshape::make(input[0], input[1]);
         return ovar.node();
     };
     auto checker = mixin_emit_checker(pattern);
     return std::make_tuple(builder, checker);
 }
+// =================== DimshuffleEmitter ====================*/
 DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
     auto&& pattern = m_pattern;
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         return opr::Dimshuffle::make(sym_var, pattern).node();
     };
-    auto checker = [pattern](VarNode* var) {
-        return var->shape().ndim == pattern.size();
+    auto checker = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        return input.front()->shape().ndim == pattern.size();
     };
     return std::make_tuple(builder, checker);
 }
+// =================== ReformatEmitter ====================*/
 ReformatEmitter::EmitResult ReformatEmitter::emit() const {
-    auto ops = analyze();
-    auto builder = [ops](VarNode* var) {
-        VarNode* ovar = var;
-        for (const auto& i : ops) {
-            ovar = i(ovar);
+    auto builders = analyze();
+    auto builder = [builders](const VarNodeArray& input) {
+        VarNode *var, *ovar;
+        var = ovar = input.front();
+        if (builders.make_shape1) {
+            auto shp1 = builders.make_shape1({var});
+            ovar = builders.reshape1({ovar, shp1});
+        }
+        ovar = builders.dimshuffle({ovar});
+        if (builders.make_shape2) {
+            auto shp2 = builders.make_shape2({var});
+            ovar = builders.reshape2({ovar, shp2});
         }
         return ovar;
     };
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
     return std::make_tuple(builder, checker);
 }
-SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
+ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     struct Dim {
         Dimension dim;
         int index;
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
         i1[i] = src_dims[src_perm[i]].dim;
         i2[i] = src_dims[src_perm[permute[i]]].dim;
     }
-    SmallVector<Builder> ops;
-    if (!m_src.eq_shape(i1))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
-    ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
-    if (!m_dest.eq_shape(i2))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
-    return ops;
+    UnderlyingBuilders builders;
+    if (!m_src.eq_shape(i1)) {
+        builders.make_shape1 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
+        builders.reshape1 =
+                std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+    }
+    builders.dimshuffle =
+            std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
+    if (!m_dest.eq_shape(i2)) {
+        builders.make_shape2 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
+        builders.reshape2 =
+                std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
+    }
+    return builders;
 }
 // vim: syntax=cpp.doxygen
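For intuition, the operator sequence `ReformatEmitter` now emits for the NCHW32 -> NCHW4 case exercised by `TestReformatEmitter.Basic` is equivalent to the hand-written graph below. This is an illustration, not code from this commit; `in` is a hypothetical input var and `tshp1`/`tshp2` stand for the symbolic target shapes that `MakeShapeEmitter` assembles from `GetVarShape`/`Concat`:

```cpp
auto x = SymbolVar(in);                  // (N, C/32, H, W, 32)
auto y0 = opr::Reshape::make(x, tshp1);  // (N, C/32, H, W, 8, 4)
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});  // (N, C/32, 8, H, W, 4)
auto y2 = opr::Reshape::make(y1, tshp2); // (N, C/4, H, W, 4)
```

Both make-shape builders read the shape of the same original input var, so the two shape subgraphs share a single `GetVarShape` node; that is why the updated `Basic` test below can assert exactly one `GetVarShape` and two `Reshape` oprs.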
@@ -20,8 +20,8 @@ namespace gopt {
 class Emitter {
 public:
-    using Builder = thin_function<VarNode*(VarNode*)>;
-    using Checker = thin_function<bool(VarNode*)>;
+    using Builder = thin_function<VarNode*(const VarNodeArray&)>;
+    using Checker = thin_function<bool(const VarNodeArray&)>;
     using EmitResult = std::tuple<Builder, Checker>;
     virtual ~Emitter() = default;
     virtual EmitResult emit() const = 0;
@@ -39,6 +39,14 @@ protected:
     megdnn::NamedTensorShape m_src, m_dest;
 };
+class MakeShapeEmitter final : public Emitter, ModifyShapeMixin {
+public:
+    MakeShapeEmitter(const megdnn::NamedTensorShape& src,
+                     const megdnn::NamedTensorShape& dest)
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
+};
 class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
 public:
     ReshapeEmitter(const megdnn::NamedTensorShape& src,
@@ -64,7 +72,10 @@ public:
     EmitResult emit() const override;
 private:
-    SmallVector<Builder> analyze() const;
+    struct UnderlyingBuilders {
+        Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle;
+    };
+    UnderlyingBuilders analyze() const;
 };
 } // namespace gopt
 } // namespace mgb
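Compared with the old `SmallVector<Builder>` chain, the named `UnderlyingBuilders` slots let `emit()` route inputs per step: each make-shape builder reads the *original* input var, which is also why `analyze()` constructs `make_shape2` from `MakeShapeEmitter(m_src, m_dest)` rather than from the intermediate `(i2, m_dest)`. A condensed sketch of the dispatch (assuming `thin_function`, like `std::function`, converts to `false` when unset, which the `if (builders.make_shape1)` tests in `emit()` rely on; `b` and `in` are hypothetical names):

```cpp
// b: UnderlyingBuilders from analyze(); in: the original input VarNode*.
VarNode* cur = in;
if (b.make_shape1)  // unset when m_src already equals intermediate i1
    cur = b.reshape1({cur, b.make_shape1({in})});
cur = b.dimshuffle({cur});
if (b.make_shape2)  // unset when intermediate i2 already equals m_dest
    cur = b.reshape2({cur, b.make_shape2({in})});
return cur;
```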
...
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
     constexpr size_t N = 12, C = 64, H = 7, W = 7;
     HostTensorGenerator<> gen;
     using NamedTensorShape = megdnn::NamedTensorShape;
-    auto dest = NamedTensorShape::make_named_tensor_shape(
-            NamedTensorShape::Format::NCHW4);
     auto src = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW32);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
     auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
     auto reformat = std::get<0>(tuple);
     auto checker = std::get<1>(tuple);
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 32, H, W, 32});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y1 = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_shapeof = 0;
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::GetVarShape>())
+            nr_shapeof++;
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_shapeof, 1);
+    ASSERT_EQ(nr_reshape, 2);
     auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
     HostTensorND t1, t2;
     auto func1 = graph->compile({make_callback_copy(y1, t1)});
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 64, H, W, 64});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y = SymbolVar(reformat({x.node()}));
     HostTensorND t;
     auto func = graph->compile({make_callback_copy(y, t)});
     func->execute();
 }
+TEST(TestReformatEmitter, EliminateRedudantReshape) {
+    constexpr size_t N = 16, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NHWC);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw_to_nhwc = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1});
+        return y.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C, H, W});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_reshape, 0);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+TEST(TestReformatEmitter, Nchw4ToNchw) {
+    constexpr size_t N = 12, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw4_to_nchw = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp);
+        return y1.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C / 4, H, W, 4});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    SmallVector<VarNode*> reshapes;
+    VarNode* dimshuffle;
+    cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>()) {
+            reshapes.push_back(o->output(0));
+        }
+        if (o->same_type<opr::Dimshuffle>())
+            dimshuffle = o->output(0);
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(reshapes.size(), 1);
+    {
+        gopt::SubGraph graph({y1});
+        gopt::UniqReaderCheck check(graph);
+        EXPECT_TRUE(check(reshapes[0]));
+        EXPECT_TRUE(dimshuffle);
+    }
+    auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
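Two of the new tests pin down the graph-level payoff of the split. `EliminateRedudantReshape` covers NCHW -> NHWC, where both layouts are four plain dimensions: `analyze()` leaves both make-shape slots unset, so the emitted graph contains zero `Reshape` oprs and reduces to the single Dimshuffle of the test's `nchw_to_nhwc` reference (sketch mirroring that helper):

```cpp
// NCHW -> NHWC is a pure permutation; no shape computation is emitted.
auto y = opr::Dimshuffle::make(SymbolVar(x), {0, 2, 3, 1});
```

`Nchw4ToNchw` covers the one-sided case: the source is already in split form, so only the trailing reshape survives (folding C/4 x 4 back into C), and the `DepOprIter` walk asserts exactly one `Reshape` alongside the `Dimshuffle`.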