Commit e8a5932d authored by Megvii Engine Team

perf(mgb/gopt): optimize impl of reformat builders

GitOrigin-RevId: 844b7e8d393290a6235e70d8455ea2d6d0e124cd
Parent 58b8b145
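In short, this patch reworks the `Builder`/`Checker` functors emitted by the gopt reformat emitters to take a `const VarNodeArray&` instead of a single `VarNode*`, and splits `ReformatEmitter` into underlying make-shape/reshape/dimshuffle builders so the target-shape subgraph is built once and shared. A minimal usage sketch, following the test code in this diff (`src`, `dest` and `x` stand for any NamedTensorShape pair and a matching input var):

    // emit() yields a (builder, checker) pair of thin_functions that now
    // take a VarNodeArray rather than a single VarNode*.
    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
    auto reformat = std::get<0>(tuple);
    auto checker = std::get<1>(tuple);
    if (checker({x.node()})) {
        // builds the reshape -> dimshuffle -> reshape chain on x
        auto y = SymbolVar(reformat({x.node()}));
    }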
@@ -19,6 +19,7 @@ using namespace gopt;
using Dimension = megdnn::Dimension;
using NamedTensorShape = megdnn::NamedTensorShape;
// =================== ModifyShapeMixin ====================*/
ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
static constexpr uint32_t UNDETERMINED_EXTENT =
Dimension::UNDETERMINED_EXTENT;
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
const Pattern& pattern) const {
auto src = m_src;
auto checker = [src, pattern](VarNode* var) {
auto checker = [src, pattern](const VarNodeArray& input) {
mgb_assert(input.size() >= 1);
const auto& var = input.front();
const auto& shp = var->shape();
if (shp.ndim != src.ndim)
return false;
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
return checker;
}
ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
// =================== MakeShapeEmitter ====================*/
MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const {
auto pattern = mixin_analyze();
auto builder = [pattern](VarNode* var) {
auto sym_var = SymbolVar(var);
auto builder = [pattern](const VarNodeArray& input) {
mgb_assert(input.size() == 1,
"number of input of MakeShapeBuilder should be 1(got:%zu)",
input.size());
auto sym_var = SymbolVar(input.front());
auto shp = opr::GetVarShape::make(sym_var);
auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
auto sub = [&shp, &cv](int ax) {
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
}
}
auto tshp = opr::Concat::make(axs, 0);
auto ovar = opr::Reshape::make(sym_var, tshp);
return tshp.node();
};
auto checker = mixin_emit_checker(pattern);
return std::make_tuple(builder, checker);
}
// =================== ReshapeEmitter ====================*/
ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
auto pattern = mixin_analyze();
auto builder = [pattern](const VarNodeArray& input) {
mgb_assert(input.size() == 2,
"number of input of Reshape should be 2(got:%zu)",
input.size());
auto ovar = opr::Reshape::make(input[0], input[1]);
return ovar.node();
};
auto checker = mixin_emit_checker(pattern);
return std::make_tuple(builder, checker);
}
// =================== DimshuffleEmitter ====================*/
DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
auto&& pattern = m_pattern;
auto builder = [pattern](VarNode* var) {
auto sym_var = SymbolVar(var);
auto builder = [pattern](const VarNodeArray& input) {
mgb_assert(input.size() == 1,
"number of input of Dimshuffle should be 1(got:%zu)",
input.size());
auto sym_var = SymbolVar(input.front());
return opr::Dimshuffle::make(sym_var, pattern).node();
};
auto checker = [pattern](VarNode* var) {
return var->shape().ndim == pattern.size();
auto checker = [pattern](const VarNodeArray& input) {
mgb_assert(input.size() == 1,
"number of input of Dimshuffle should be 1(got:%zu)",
input.size());
return input.front()->shape().ndim == pattern.size();
};
return std::make_tuple(builder, checker);
}
// =================== ReformatEmitter ====================*/
ReformatEmitter::EmitResult ReformatEmitter::emit() const {
auto ops = analyze();
auto builder = [ops](VarNode* var) {
VarNode* ovar = var;
for (const auto& i : ops) {
ovar = i(ovar);
auto builders = analyze();
auto builder = [builders](const VarNodeArray& input) {
VarNode *var, *ovar;
var = ovar = input.front();
if (builders.make_shape1) {
auto shp1 = builders.make_shape1({var});
ovar = builders.reshape1({ovar, shp1});
}
ovar = builders.dimshuffle({ovar});
if (builders.make_shape2) {
auto shp2 = builders.make_shape2({var});
ovar = builders.reshape2({ovar, shp2});
}
return ovar;
};
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
return std::make_tuple(builder, checker);
}
SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
struct Dim {
Dimension dim;
int index;
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
i1[i] = src_dims[src_perm[i]].dim;
i2[i] = src_dims[src_perm[permute[i]]].dim;
}
SmallVector<Builder> ops;
if (!m_src.eq_shape(i1))
ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
if (!m_dest.eq_shape(i2))
ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
return ops;
UnderlyingBuilders builders;
if (!m_src.eq_shape(i1)) {
builders.make_shape1 =
std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
builders.reshape1 =
std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
}
builders.dimshuffle =
std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
if (!m_dest.eq_shape(i2)) {
builders.make_shape2 =
std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
builders.reshape2 =
std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
}
return builders;
}
// vim: syntax=cpp.doxygen
@@ -20,8 +20,8 @@ namespace gopt {
class Emitter {
public:
using Builder = thin_function<VarNode*(VarNode*)>;
using Checker = thin_function<bool(VarNode*)>;
using Builder = thin_function<VarNode*(const VarNodeArray&)>;
using Checker = thin_function<bool(const VarNodeArray&)>;
using EmitResult = std::tuple<Builder, Checker>;
virtual ~Emitter() = default;
virtual EmitResult emit() const = 0;
@@ -39,6 +39,14 @@ protected:
megdnn::NamedTensorShape m_src, m_dest;
};
class MakeShapeEmitter final : public Emitter, ModifyShapeMixin {
public:
MakeShapeEmitter(const megdnn::NamedTensorShape& src,
const megdnn::NamedTensorShape& dest)
: ModifyShapeMixin(src, dest) {}
EmitResult emit() const override;
};
class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
public:
ReshapeEmitter(const megdnn::NamedTensorShape& src,
@@ -64,7 +72,10 @@ public:
EmitResult emit() const override;
private:
SmallVector<Builder> analyze() const;
struct UnderlyingBuilders {
Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle;
};
UnderlyingBuilders analyze() const;
};
} // namespace gopt
} // namespace mgb
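For reference, a sketch of how the `UnderlyingBuilders` members compose, mirroring `ReformatEmitter::emit()` in the .cpp hunk above (`apply_reformat` is purely illustrative and not part of the patch; `make_shape1`/`make_shape2` are left empty when the corresponding reshape is a no-op):

    VarNode* apply_reformat(const UnderlyingBuilders& b, VarNode* x) {
        VarNode* cur = x;
        if (b.make_shape1)  // reshape to the intermediate shape i1
            cur = b.reshape1({cur, b.make_shape1({x})});
        cur = b.dimshuffle({cur});  // permute the axes
        if (b.make_shape2)  // reshape to the destination shape
            cur = b.reshape2({cur, b.make_shape2({x})});
        return cur;
    }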
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
constexpr size_t N = 12, C = 64, H = 7, W = 7;
HostTensorGenerator<> gen;
using NamedTensorShape = megdnn::NamedTensorShape;
auto dest = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW4);
auto src = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW32);
auto dest = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW4);
auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
auto reformat = std::get<0>(tuple);
auto checker = std::get<1>(tuple);
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto x = mkvar("x", {N, C / 32, H, W, 32});
EXPECT_TRUE(checker(x.node()));
EXPECT_TRUE(checker({x.node()}));
auto x_ = mkvar("x", {N, H, W, C});
EXPECT_FALSE(checker(x_.node()));
auto y1 = SymbolVar(reformat(x.node()));
EXPECT_FALSE(checker({x_.node()}));
auto y1 = SymbolVar(reformat({x.node()}));
size_t nr_shapeof = 0;
size_t nr_reshape = 0;
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
if (o->same_type<opr::GetVarShape>())
nr_shapeof++;
if (o->same_type<opr::Reshape>())
nr_reshape++;
}}
.add(y1.node()->owner_opr());
ASSERT_EQ(nr_shapeof, 1);
ASSERT_EQ(nr_reshape, 2);
auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y1, t1)});
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto x = mkvar("x", {N, C / 64, H, W, 64});
EXPECT_TRUE(checker(x.node()));
EXPECT_TRUE(checker({x.node()}));
auto x_ = mkvar("x", {N, H, W, C});
EXPECT_FALSE(checker(x_.node()));
auto y = SymbolVar(reformat(x.node()));
EXPECT_FALSE(checker({x_.node()}));
auto y = SymbolVar(reformat({x.node()}));
HostTensorND t;
auto func = graph->compile({make_callback_copy(y, t)});
func->execute();
}
TEST(TestReformatEmitter, EliminateRedudantReshape) {
constexpr size_t N = 16, C = 64, H = 7, W = 7;
HostTensorGenerator<> gen;
using NamedTensorShape = megdnn::NamedTensorShape;
auto src = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW);
auto dest = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NHWC);
auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
auto reformat = std::get<0>(tuple);
auto checker = std::get<1>(tuple);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto nchw_to_nhwc = [](VarNode* in) {
auto x = SymbolVar(in);
auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1});
return y.node();
};
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto x = mkvar("x", {N, C, H, W});
EXPECT_TRUE(checker({x.node()}));
auto y1 = SymbolVar(reformat({x.node()}));
size_t nr_reshape = 0;
cg::DepOprIter{[&nr_reshape](cg::OperatorNodeBase* o) {
if (o->same_type<opr::Reshape>())
nr_reshape++;
}}
.add(y1.node()->owner_opr());
ASSERT_EQ(nr_reshape, 0);
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y1, t1)});
func1->execute();
auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
auto func2 = graph->compile({make_callback_copy(y2, t2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
TEST(TestReformatEmitter, Nchw4ToNchw) {
constexpr size_t N = 12, C = 64, H = 7, W = 7;
HostTensorGenerator<> gen;
using NamedTensorShape = megdnn::NamedTensorShape;
auto src = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW4);
auto dest = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW);
auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
auto reformat = std::get<0>(tuple);
auto checker = std::get<1>(tuple);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto nchw4_to_nchw = [](VarNode* in) {
auto x = SymbolVar(in);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp);
return y1.node();
};
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto x = mkvar("x", {N, C / 4, H, W, 4});
EXPECT_TRUE(checker({x.node()}));
auto y1 = SymbolVar(reformat({x.node()}));
SmallVector<VarNode*> reshapes;
VarNode* dimshuffle = nullptr;
cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
if (o->same_type<opr::Reshape>()) {
reshapes.push_back(o->output(0));
}
if (o->same_type<opr::Dimshuffle>())
dimshuffle = o->output(0);
}}
.add(y1.node()->owner_opr());
ASSERT_EQ(reshapes.size(), 1);
{
gopt::SubGraph graph({y1});
gopt::UniqReaderCheck check(graph);
EXPECT_TRUE(check(reshapes[0]));
EXPECT_TRUE(dimshuffle);
}
auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y1, t1)});
func1->execute();
auto func2 = graph->compile({make_callback_copy(y2, t2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}