Commit 7f3f9a94 authored by Megvii Engine Team

feat(mgb/core): add shape hint for graph optimization

GitOrigin-RevId: eaad25a7efe61c388f0e45fa780fdbbb12402ae7
Parent d1fbec4f
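
For reference, a minimal usage sketch assembled from the test added in this commit (TestOprUtility.ShapeHint); it only uses names that appear in the diff below and is illustrative, not new API:

#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/test/helper.h"

using namespace mgb;

// Attach a hint so graph optimization sees a static shape for a var whose
// shape cannot be statically inferred; RemoveShapeHintPass strips the opr
// again during graph compilation.
void shape_hint_sketch() {
    HostTensorGenerator<dtype::Int32> gen_int;
    auto host_x = gen_int({233});
    auto graph = ComputingGraph::make();
    SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x);
    SymbolVar x_dyn = opr::MarkDynamicVar::make(x);     // static shape unknown
    SymbolVar x_hint = opr::ShapeHint::make(x_dyn, TensorShape{233});
    SymbolVar y = x_hint * 2 + 1;                       // y.shape() now infers to {233}
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();                                    // runtime shape follows host_x
}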
......@@ -514,6 +514,16 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
optimizer.add_passes_for_optimize_options(options().graph_opt, true);
optimizer.apply_inplace(dest_vars);
if (sopr_stat.has_shape_hint) {
// FIXME(zhangxuanrun): strictly speaking, ShapeHint oprs could and
// should be removed even when they occur in a subgraph
mgb_assert(!m_parent_graph, "cannot use ShapeHint in subgraph");
// shape hints must always be removed before graph compilation
gopt::GraphOptimizer opt;
opt.add_pass<gopt::RemoveShapeHintPass>();
opt.apply_inplace(dest_vars);
}
const OprNodeArray* opr_seq = nullptr;
CompSeqExtraInfo extra_info;
cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
......
......@@ -564,6 +564,9 @@ void ExtraDependencyMerger::on_opr(OperatorNodeBase* opr) {
sopr_stat->has_virtual_grad = true;
}
#endif
if (sopr_stat && opr->same_type<opr::ShapeHint>()) {
sopr_stat->has_shape_hint = true;
}
}
}
......
......@@ -149,6 +149,7 @@ SymbolVar current_grad_target(ComputingGraph &graph);
struct SpecialOprStat {
bool has_virtual_grad = false;
bool has_shape_hint = false;
};
/*!
......
......@@ -678,6 +678,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
add_pass<ParamMergePass>();
add_pass<FuseDeconvCvtPass>();
}
if (inference_opt) {
// remove shape hint after inference optimization
add_pass<RemoveShapeHintPass>();
}
return *this;
}
......
......@@ -1055,4 +1055,30 @@ void PackAllReduceReplacePass::insert_packed_oprs(
#endif // MGB_ENABLE_OPR_MM
/* ======================= RemoveShapeHintPass ====================== */
const char* RemoveShapeHintPass::name() const {
return "remove_shape_hint";
}
void RemoveShapeHintPass::apply(OptState& opt) const {
MIDOUT_B("RemoveShapeHintPass::apply")
opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_DTYPE);
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&](OperatorNodeBase* opr) {
if (auto sh = try_cast_as_op<opr::ShapeHint>(opr)) {
auto inp = rewriter.get_var(sh->input(0));
rewriter.replace_var(sh->output(0), inp,
mgb_cstr_log("remove shape hint"));
return;
}
rewriter.auto_replace_outputs(opr);
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -141,6 +141,12 @@ namespace gopt {
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority);
};
class RemoveShapeHintPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};
} // namespace gopt
} // namespace mgb
......
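
The new pass can also be applied on its own, mirroring the compile_prepare hunk above (a sketch; dest_vars stands for whatever VarNodeArray is being optimized):

// strip all ShapeHint oprs from dest_vars in place
gopt::GraphOptimizer opt;
opt.add_pass<gopt::RemoveShapeHintPass>();
opt.apply_inplace(dest_vars);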
......@@ -840,4 +840,57 @@ SymbolVar RequireInputDynamicStorage::make(const SymbolVar input,
input.node(), config);
}
/* ===================== ShapeHint ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShapeHint);
void ShapeHint::scn_do_execute() {
// ShapeHint oprs are removed by RemoveShapeHintPass during graph
// compilation, so execution must never reach this point
mgb_assert(0);
}
void ShapeHint::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_shp = [this](TensorShape& dest, const InpVal&) -> bool {
const TensorShape* inferred = nullptr;
if (cg::is_static_var_shape(input(0))) {
inferred = owner_graph()->static_infer_manager().infer_shape_fallible(input(0));
}
if (inferred) {
dest = *inferred;
if (!dest.eq_shape(m_shape)) {
mgb_log_warn(
"given shape hint on var %s is different from inferred shape, "
"hint %s vs inferred %s", cg::dump_var_info({input(0)}).c_str(),
m_shape.to_string().c_str(), dest.to_string().c_str());
}
} else {
dest = m_shape;
}
return dest.ndim;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(0), {m_is_const ? SourceType::CONSTANT : SourceType::MUTABLE, {}, infer_shp});
}
ShapeHint::ShapeHint(VarNode* inp, TensorShape shape,
bool is_const, const OperatorNodeConfig& config)
: Super{inp->owner_graph(), config, "shape_hint", {inp}},
m_shape(shape), m_is_const(is_const) {
add_input({inp});
add_output(None);
}
SymbolVar ShapeHint::make(SymbolVar inp, TensorShape shape,
bool is_const, const OperatorNodeConfig& config) {
return inp.insert_single_output_opr<ShapeHint>(inp.node(), shape, is_const, config);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ShapeHint) {
// the gradient w.r.t. input(0) is out_grad(0) unchanged; since the shape
// of output(0) can already be inferred, there is no need to attach
// another hint to out_grad(0)
return out_grad.at(0);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -90,4 +90,15 @@ decl_opr(
params='Empty'
)
decl_raw_opr(
'shape_hint',
desc='a special op that provides a shape hint; only used during graph compilation',
inputs=[Doc('input', 'input var the shape hint is attached to'),
Doc('shape', 'given hint shape', 'list of int'),
Doc('is_const', 'whether to treat the given shape as constant', 'bool', 'False')],
body=[
'output = _mgb._Opr.shape_hint(input, shape, is_const, config)'
]
)
# vim: ft=python
......@@ -153,6 +153,17 @@ namespace opr {
#endif
MGB_SEREG_OPR(PersistentOutputStorage, 1);
cg::OperatorNodeBase* opr_shallow_copy_shape_hint(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto &&opr = opr_.cast_final_safe<ShapeHint>();
mgb_assert(inputs.size() == 1);
return ShapeHint::make(inputs[0], opr.shape(), opr.is_const(), config)
.node()->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(ShapeHint, opr_shallow_copy_shape_hint);
} // namespace opr
} // namespace mgb
......
......@@ -512,6 +512,27 @@ public:
const OperatorNodeConfig& config = {});
};
/*!
* \brief a special op that provides a shape hint; only used during graph
* compilation (gopt)
*/
MGB_DEFINE_OPR_CLASS(ShapeHint, cg::SingleCNOperatorNodeBase) // {
TensorShape m_shape;
bool m_is_const;
void scn_do_execute() override;
void init_output_static_infer_desc() override;
public:
ShapeHint(VarNode* inp, const TensorShape shape,
bool is_const, const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar inp, const TensorShape shape,
bool is_const=false, const OperatorNodeConfig& config = {});
TensorShape shape() const { return m_shape; }
bool is_const() const { return m_is_const; }
};
} // namespace opr
} // namespace mgb
......
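
A short illustration of the accessors declared above (values are hypothetical; cast_final_safe is the same helper used in the shallow-copy hunk earlier):

// is_const=true registers the hint as SourceType::CONSTANT with the
// static shape infer manager instead of SourceType::MUTABLE
auto v = opr::ShapeHint::make(x, TensorShape{16, 32}, /*is_const=*/true);
auto&& sh = v.node()->owner_opr()->cast_final_safe<opr::ShapeHint>();
mgb_assert(sh.shape().eq_shape({16, 32}));
mgb_assert(sh.is_const());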
......@@ -12,6 +12,7 @@
#include "megbrain/opr/utility.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/test/helper.h"
using namespace mgb;
......@@ -467,4 +468,64 @@ TEST(TestOprUtility, RequireInputDynamicStorage) {
ASSERT_LT(nr_opr(func), nr0);
}
TEST(TestOprUtility, ShapeHint) {
HostTensorGenerator<> gen;
HostTensorGenerator<dtype::Int32> gen_int;
constexpr size_t length = 233;
{ // basic
for (bool dynamic : {false, true}) {
auto host_x = gen_int({length});
auto graph = ComputingGraph::make();
SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x), x_shape_hint, y;
if (dynamic) {
x_shape_hint = opr::ShapeHint::make(opr::MarkDynamicVar::make(x), TensorShape{length * 2});
} else {
x_shape_hint = opr::ShapeHint::make(x, TensorShape{length * 2});
}
y = x_shape_hint * 2 + 1;
if (dynamic) {
ASSERT_TRUE(y.shape().eq_shape({length * 2}));
} else {
ASSERT_TRUE(y.shape().eq_shape({length}));
}
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_TRUE(host_y.shape().eq_shape({length}));
for (size_t i = 0; i < length; ++ i) {
ASSERT_EQ(host_x->ptr<int32_t>()[i] * 2 + 1, host_y.ptr<int32_t>()[i]);
}
}
}
{ // shallow copy
auto graph = ComputingGraph::make();
auto host_x = gen({length});
SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::ShapeHint::make(x, TensorShape{length * 2}),
x_unknown = opr::MarkDynamicVar::make(x),
y_copy = serialization::copy_opr_shallow(
*y.node()->owner_opr(), {x_unknown.node()})->output(0);
ASSERT_TRUE(y.shape().eq_shape({length}));
ASSERT_TRUE(y_copy.shape().eq_shape({length * 2}));
}
{ // grad
auto host_x = gen({1}), host_y = gen({1});
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Host2DeviceCopy::make(*graph, host_y),
x_shape_hint = opr::ShapeHint::make(opr::MarkDynamicVar::make(x), TensorShape{1}),
y_shape_hint = opr::ShapeHint::make(y, TensorShape{1}),
t = x_shape_hint * y_shape_hint;
HostTensorND host_gx, host_gy;
auto func = graph->compile({
make_callback_copy(cg::grad(t, x), host_gx),
make_callback_copy(cg::grad(t, y), host_gy)
});
func->execute();
ASSERT_TRUE(host_gx.shape().is_scalar());
ASSERT_TRUE(host_gy.shape().is_scalar());
ASSERT_FLOAT_EQ(*host_x->ptr<float>(), *host_gy.ptr<float>());
ASSERT_FLOAT_EQ(*host_y->ptr<float>(), *host_gx.ptr<float>());
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}