Commit 58b8b145 authored by Megvii Engine Team

refactor(mgb/gopt): add checker for reformat emitter

GitOrigin-RevId: 53a8c128f57e05147a0acaffbf52fe55bcbad281
Parent 55efc8e1
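Note: this commit changes the `Emitter::emit()` contract. It now returns an `EmitResult`, a tuple of a `Builder` (which constructs the reformat subgraph on a `VarNode`) and a `Checker` (which verifies that a concrete var shape is compatible with the source layout before the subgraph is built). A minimal usage sketch, mirroring the updated tests below (graph setup elided; `src`, `dest`, and `x` are as in the tests):

```cpp
// Sketch only: assumes a ComputingGraph with an input var x and two
// megdnn::NamedTensorShape values src/dest, as in the tests below.
auto&& result = gopt::ReformatEmitter(src, dest).emit();
auto builder = std::get<0>(result);  // Emitter::Builder: VarNode* -> VarNode*
auto checker = std::get<1>(result);  // Emitter::Checker: VarNode* -> bool
if (checker(x.node())) {
    // Only build the reshape/dimshuffle chain when the shape qualifies.
    VarNode* y = builder(x.node());
}
```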
@@ -10,8 +10,8 @@
  * implied.
  */
-#include <numeric>
 #include "megbrain/gopt/reformat_emitter.h"
+#include <numeric>
 #include "megbrain/opr/tensor_manip.h"
 using namespace mgb;
@@ -19,34 +19,7 @@ using namespace gopt;
 using Dimension = megdnn::Dimension;
 using NamedTensorShape = megdnn::NamedTensorShape;
-ReshapeEmitter::Operator ReshapeEmitter::emit() const {
-    auto pattern = analyze();
-    auto op = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
-        auto shp = opr::GetVarShape::make(sym_var);
-        auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
-        auto sub = [&shp, &cv](int ax) {
-            return opr::IndexAt::make(shp, {{0, cv(ax)}});
-        };
-        SymbolVarArray axs;
-        for (auto i : pattern) {
-            if (std::get<0>(i) >= 0) {
-                if (std::get<2>(i))
-                    axs.emplace_back(sub(std::get<0>(i)) * std::get<1>(i));
-                else
-                    axs.emplace_back(sub(std::get<0>(i)) / std::get<1>(i));
-            } else {
-                axs.emplace_back(cv(std::get<1>(i)));
-            }
-        }
-        auto tshp = opr::Concat::make(axs, 0);
-        auto ovar = opr::Reshape::make(sym_var, tshp);
-        return ovar.node();
-    };
-    return op;
-}
-SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const {
+ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
     ThinHashMap<Dimension::Name, int> name2dominant;
@@ -58,7 +31,7 @@ SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const {
         }
     }
-    SmallVector<std::tuple<int, int, bool>> pattern(m_dest.ndim);
+    Pattern pattern(m_dest.ndim);
     for (size_t i = 0; i < m_dest.ndim; ++i) {
         auto name = m_dest[i].name();
         if (m_dest[i].extent() == UNDETERMINED_EXTENT) {
@@ -74,28 +47,90 @@ SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const {
     return pattern;
 }
-DimshuffleEmitter::Operator DimshuffleEmitter::emit() const {
-    auto pattern = m_pattern;
-    auto op = [pattern](VarNode* var) {
+ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
+        const Pattern& pattern) const {
+    auto src = m_src;
+    auto checker = [src, pattern](VarNode* var) {
+        const auto& shp = var->shape();
+        if (shp.ndim != src.ndim)
+            return false;
+        bool available = true;
+        for (size_t i = 0; i < shp.ndim; ++i) {
+            if (src[i].extent() != Dimension::UNDETERMINED_EXTENT) {
+                available &= (shp[i] == src[i].extent());
+            }
+        }
+        for (auto&& i : pattern) {
+            int axis, factor;
+            bool mul;
+            std::tie(axis, factor, mul) = i;
+            if (axis >= 0 && !mul) {
+                available &= (shp[axis] % factor == 0);
+            }
+        }
+        return available;
+    };
+    return checker;
+}
+ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+    auto pattern = mixin_analyze();
+    auto builder = [pattern](VarNode* var) {
+        auto sym_var = SymbolVar(var);
+        auto shp = opr::GetVarShape::make(sym_var);
+        auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
+        auto sub = [&shp, &cv](int ax) {
+            return opr::IndexAt::make(shp, {{0, cv(ax)}});
+        };
+        SymbolVarArray axs;
+        for (auto&& i : pattern) {
+            int axis, factor;
+            bool mul;
+            std::tie(axis, factor, mul) = i;
+            if (axis >= 0) {
+                if (mul)
+                    axs.emplace_back(sub(axis) * factor);
+                else
+                    axs.emplace_back(sub(axis) / factor);
+            } else {
+                axs.emplace_back(cv(factor));
+            }
+        }
+        auto tshp = opr::Concat::make(axs, 0);
+        auto ovar = opr::Reshape::make(sym_var, tshp);
+        return ovar.node();
+    };
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
+}
+DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
+    auto&& pattern = m_pattern;
+    auto builder = [pattern](VarNode* var) {
         auto sym_var = SymbolVar(var);
         return opr::Dimshuffle::make(sym_var, pattern).node();
     };
-    return op;
+    auto checker = [pattern](VarNode* var) {
+        return var->shape().ndim == pattern.size();
+    };
+    return std::make_tuple(builder, checker);
 }
-ReformatEmitter::Operator ReformatEmitter::emit() const {
+ReformatEmitter::EmitResult ReformatEmitter::emit() const {
     auto ops = analyze();
-    auto op = [ops](VarNode* var) {
+    auto builder = [ops](VarNode* var) {
         VarNode* ovar = var;
-        for (const auto& o : ops) {
-            ovar = o(ovar);
+        for (const auto& i : ops) {
+            ovar = i(ovar);
         }
         return ovar;
     };
-    return op;
+    auto pattern = mixin_analyze();
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
 }
-SmallVector<ReformatEmitter::Operator> ReformatEmitter::analyze() const {
+SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
     struct Dim {
         Dimension dim;
         int index;
@@ -161,12 +196,12 @@ SmallVector<ReformatEmitter::Operator> ReformatEmitter::analyze() const {
         i1[i] = src_dims[src_perm[i]].dim;
         i2[i] = src_dims[src_perm[permute[i]]].dim;
     }
-    SmallVector<Operator> ops;
+    SmallVector<Builder> ops;
     if (!m_src.eq_shape(i1))
-        ops.emplace_back(ReshapeEmitter(m_src, i1).emit());
-    ops.emplace_back(DimshuffleEmitter(permute).emit());
+        ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+    ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
     if (!m_dest.eq_shape(i2))
-        ops.emplace_back(ReshapeEmitter(i2, m_dest).emit());
+        ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
     return ops;
 }
 // vim: syntax=cpp.doxygen
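The `Pattern` consumed above is a list of `(axis, factor, mul)` tuples produced by `mixin_analyze()`: a non-negative `axis` means the output extent is the source extent at `axis` multiplied (`mul == true`) or divided (`mul == false`) by `factor`, while a negative `axis` denotes the constant extent `factor`. The checker therefore rejects any shape whose divided axes are not exact multiples. A standalone sketch of that rule with plain std types (the concrete tuples for an NCHW -> NCHW4 split are hand-written here, an assumption rather than actual `mixin_analyze()` output):

```cpp
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <vector>

using Pattern = std::vector<std::tuple<int, int, bool>>;

// Mirrors the divisibility test in mixin_emit_checker: an axis that gets
// divided by `factor` (mul == false) must divide exactly.
bool shape_matches(const std::vector<std::size_t>& shp, const Pattern& pattern) {
    bool available = true;
    for (auto&& [axis, factor, mul] : pattern) {
        if (axis >= 0 && !mul)
            available &= (shp[axis] % factor == 0);
    }
    return available;
}

int main() {
    // Hypothetical NCHW -> NCHW4 pattern: channel axis 1 is divided by 4,
    // and a trailing constant axis of extent 4 is appended.
    Pattern nchw_to_nchw4{
            {0, 1, true}, {1, 4, false}, {2, 1, true}, {3, 1, true}, {-1, 4, false}};
    std::printf("%d\n", shape_matches({8, 16, 32, 32}, nchw_to_nchw4));  // 1: 16 % 4 == 0
    std::printf("%d\n", shape_matches({8, 18, 32, 32}, nchw_to_nchw4));  // 0: 18 % 4 != 0
}
```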
@@ -20,45 +20,51 @@ namespace gopt {
 class Emitter {
 public:
-    using Operator = thin_function<VarNode*(VarNode*)>;
+    using Builder = thin_function<VarNode*(VarNode*)>;
+    using Checker = thin_function<bool(VarNode*)>;
+    using EmitResult = std::tuple<Builder, Checker>;
     virtual ~Emitter() = default;
-    virtual Operator emit() const = 0;
+    virtual EmitResult emit() const = 0;
 };
-class ReshapeEmitter final : public Emitter {
+class ModifyShapeMixin {
+protected:
+    using Pattern = SmallVector<std::tuple<int, int, bool>>;
+    using Checker = Emitter::Checker;
+    ModifyShapeMixin(const megdnn::NamedTensorShape& src,
+                     const megdnn::NamedTensorShape& dest)
+            : m_src(src), m_dest(dest) {}
+    Pattern mixin_analyze() const;
+    Checker mixin_emit_checker(const Pattern& pattern) const;
+    megdnn::NamedTensorShape m_src, m_dest;
+};
+class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
 public:
-    using Operator = typename Emitter::Operator;
     ReshapeEmitter(const megdnn::NamedTensorShape& src,
                    const megdnn::NamedTensorShape& dest)
-            : m_src{src}, m_dest{dest} {}
-    Operator emit() const override;
-private:
-    SmallVector<std::tuple<int, int, bool>> analyze() const;
-    megdnn::NamedTensorShape m_src, m_dest;
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
 };
 class DimshuffleEmitter final : public Emitter {
 public:
-    using Operator = typename Emitter::Operator;
     DimshuffleEmitter(const std::vector<int>& pattern) : m_pattern{pattern} {}
-    Operator emit() const override;
+    EmitResult emit() const override;
 private:
     std::vector<int> m_pattern;
 };
-class ReformatEmitter final : public Emitter {
+class ReformatEmitter final : public Emitter, ModifyShapeMixin {
 public:
-    using Operator = typename Emitter::Operator;
     ReformatEmitter(const megdnn::NamedTensorShape& src,
                     const megdnn::NamedTensorShape& dest)
-            : m_src{src}, m_dest{dest} {}
-    Operator emit() const override;
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
 private:
-    SmallVector<Operator> analyze() const;
-    megdnn::NamedTensorShape m_src, m_dest;
+    SmallVector<Builder> analyze() const;
 };
 } // namespace gopt
 } // namespace mgb
......
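Header-side, the duplicated `m_src`/`m_dest` members and per-class `analyze()` declarations collapse into `ModifyShapeMixin`, which `ReshapeEmitter` and `ReformatEmitter` now inherit alongside `Emitter`. A hypothetical sketch of how a further shape-modifying emitter could reuse the mixin (the class name and no-op builder are invented for illustration, not part of this commit):

```cpp
// Hypothetical example only; assumes megbrain/gopt/reformat_emitter.h is
// included and the code sits inside namespace mgb::gopt.
class NoopEmitter final : public Emitter, ModifyShapeMixin {
public:
    NoopEmitter(const megdnn::NamedTensorShape& src,
                const megdnn::NamedTensorShape& dest)
            : ModifyShapeMixin(src, dest) {}
    EmitResult emit() const override {
        // Trivial builder; a real emitter would construct operators here.
        Builder builder = [](VarNode* var) { return var; };
        // Reuse the mixin: derive the pattern once, then emit its checker.
        auto checker = mixin_emit_checker(mixin_analyze());
        return std::make_tuple(builder, checker);
    }
};
```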
@@ -25,7 +25,9 @@ TEST(TestReformatEmitter, Basic) {
             NamedTensorShape::Format::NCHW4);
     auto src = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW32);
-    auto reformat = gopt::ReformatEmitter(src, dest).emit();
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
@@ -51,6 +53,9 @@ TEST(TestReformatEmitter, Basic) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 32, H, W, 32});
+    EXPECT_TRUE(checker(x.node()));
+    auto x_ = mkvar("x", {N, H, W, C});
+    EXPECT_FALSE(checker(x_.node()));
     auto y1 = SymbolVar(reformat(x.node()));
     auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
     HostTensorND t1, t2;
@@ -69,7 +74,9 @@ TEST(TestReformatEmitter, MoreComplicated) {
             NamedTensorShape::Format::NCHW64);
     auto dest = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW88);
-    auto reformat = gopt::ReformatEmitter(src, dest).emit();
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
@@ -77,6 +84,9 @@
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 64, H, W, 64});
+    EXPECT_TRUE(checker(x.node()));
+    auto x_ = mkvar("x", {N, H, W, C});
+    EXPECT_FALSE(checker(x_.node()));
     auto y = SymbolVar(reformat(x.node()));
     HostTensorND t;
     auto func = graph->compile({make_callback_copy(y, t)});
......
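For the `Basic` test above, `analyze()` decomposes NCHW32 -> NCHW4 into at most three builders (reshape, dimshuffle, reshape). A hand-derived sketch of the intermediate shapes, checked with concrete numbers (this illustrates the decomposition; it is not output captured from the emitter):

```cpp
#include <cassert>
#include <cstddef>

int main() {
    const std::size_t N = 2, C = 64, H = 8, W = 8;
    std::size_t s0[] = {N, C / 32, H, W, 32};    // input, NCHW32
    std::size_t s1[] = {N, C / 32, H, W, 8, 4};  // reshape: split 32 into 8 x 4
    std::size_t s2[] = {N, C / 32, 8, H, W, 4};  // dimshuffle: move the 8 inward
    std::size_t s3[] = {N, C / 4, H, W, 4};      // reshape: merge (C/32) * 8 = C/4
    assert(s1[4] * s1[5] == s0[4]);  // 8 * 4 == 32
    assert(s2[1] * s2[2] == s3[1]);  // (C/32) * 8 == C/4
    return 0;
}
```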