From 7f3f9a94ae79801f3bb47f68128f349e99ebc03a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 24 Dec 2020 18:50:43 +0800
Subject: [PATCH] feat(mgb/core): add shape hint for graph optimization

GitOrigin-RevId: eaad25a7efe61c388f0e45fa780fdbbb12402ae7
---
 src/core/impl/graph/cg_impl.cpp          | 10 ++++
 src/core/impl/graph/helper.cpp           |  3 ++
 src/core/include/megbrain/graph/helper.h |  1 +
 src/gopt/impl/framework.cpp              |  5 ++
 src/gopt/impl/misc.cpp                   | 26 ++++++++++
 src/gopt/include/megbrain/gopt/misc.h    |  6 +++
 src/opr/impl/utility.cpp                 | 53 ++++++++++++++++++++
 src/opr/impl/utility.oprdecl             | 11 +++++
 src/opr/impl/utility.sereg.h             | 11 +++++
 src/opr/include/megbrain/opr/utility.h   | 21 ++++++++
 src/opr/test/utility.cpp                 | 61 ++++++++++++++++++++++++
 11 files changed, 208 insertions(+)

diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index 46e22e9d..dbead12e 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -514,6 +514,16 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     optimizer.add_passes_for_optimize_options(options().graph_opt, true);
     optimizer.apply_inplace(dest_vars);
 
+    if (sopr_stat.has_shape_hint) {
+        // FIXME(zhangxuanrun): strictly speaking, ShapeHint oprs should be
+        // removed even when they occur in a subgraph
+        mgb_assert(!m_parent_graph, "can not use ShapeHint in subgraph");
+        // ShapeHint is a compile-time hint only, so it must always be removed
+        gopt::GraphOptimizer opt;
+        opt.add_pass<gopt::RemoveShapeHintPass>();
+        opt.apply_inplace(dest_vars);
+    }
+
     const OprNodeArray* opr_seq = nullptr;
     CompSeqExtraInfo extra_info;
     cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
diff --git a/src/core/impl/graph/helper.cpp b/src/core/impl/graph/helper.cpp
index 7e221250..fe35b9df 100644
--- a/src/core/impl/graph/helper.cpp
+++ b/src/core/impl/graph/helper.cpp
@@ -564,6 +564,9 @@ void ExtraDependencyMerger::on_opr(OperatorNodeBase* opr) {
             sopr_stat->has_virtual_grad = true;
         }
 #endif
+        if (sopr_stat && opr->same_type<opr::ShapeHint>()) {
+            sopr_stat->has_shape_hint = true;
+        }
     }
 }
diff --git a/src/core/include/megbrain/graph/helper.h b/src/core/include/megbrain/graph/helper.h
index e799de96..0f7f8ad3 100644
--- a/src/core/include/megbrain/graph/helper.h
+++ b/src/core/include/megbrain/graph/helper.h
@@ -149,6 +149,7 @@ SymbolVar current_grad_target(ComputingGraph &graph);
 
 struct SpecialOprStat {
     bool has_virtual_grad = false;
+    bool has_shape_hint = false;
 };
 
 /*!
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index e7cbe94d..52d029b2 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -678,6 +678,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
         add_pass();
         add_pass();
     }
+
+    if (inference_opt) {
+        // remove shape hint after inference optimization
+        add_pass<RemoveShapeHintPass>();
+    }
     return *this;
 }
diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp
index 04379772..8c38228e 100644
--- a/src/gopt/impl/misc.cpp
+++ b/src/gopt/impl/misc.cpp
@@ -1055,4 +1055,30 @@ void PackAllReduceReplacePass::insert_packed_oprs(
 
 #endif  // MGB_ENABLE_OPR_MM
 
+/* ======================= RemoveShapeHintPass ====================== */
+
+const char* RemoveShapeHintPass::name() const {
+    return "remove_shape_hint";
+}
+
+void RemoveShapeHintPass::apply(OptState& opt) const {
+    MIDOUT_B("RemoveShapeHintPass::apply")
+    opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_DTYPE);
+    auto rewriter = opt.graph().make_rewriter();
+
+    auto on_opr = [&](OperatorNodeBase* opr) {
+        if (auto sh = try_cast_as_op<opr::ShapeHint>(opr)) {
+            // forward the (possibly already rewritten) input var, dropping
+            // the ShapeHint opr itself
+            auto inp = rewriter.get_var(sh->input(0));
+            rewriter.replace_var(sh->output(0), inp,
+                    mgb_cstr_log("remove shape hint"));
+            return;
+        }
+        rewriter.auto_replace_outputs(opr);
+    };
+
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+    MIDOUT_E
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/gopt/include/megbrain/gopt/misc.h b/src/gopt/include/megbrain/gopt/misc.h
index c6aa63c1..f6fc87ad 100644
--- a/src/gopt/include/megbrain/gopt/misc.h
+++ b/src/gopt/include/megbrain/gopt/misc.h
@@ -141,6 +141,12 @@ namespace gopt {
                 ThinHashMap<VarNode*, VarNode*>& replace_map, int priority);
     };
 
+    class RemoveShapeHintPass final : public Pass {
+    public:
+        const char* name() const override;
+        void apply(OptState& opt) const override;
+    };
+
 }  // namespace gopt
 }  // namespace mgb
diff --git a/src/opr/impl/utility.cpp b/src/opr/impl/utility.cpp
index 512e9742..0463572e 100644
--- a/src/opr/impl/utility.cpp
+++ b/src/opr/impl/utility.cpp
@@ -840,4 +840,57 @@ SymbolVar RequireInputDynamicStorage::make(const SymbolVar input,
             input.node(), config);
 }
 
+/* ===================== ShapeHint ===================== */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShapeHint);
+
+void ShapeHint::scn_do_execute() {
+    // ShapeHint must be removed before execution, so it can never run
+    mgb_assert(0);
+}
+
+void ShapeHint::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto infer_shp = [this](TensorShape& dest, const InpVal&) -> bool {
+        const TensorShape* inferred = nullptr;
+        if (cg::is_static_var_shape(input(0))) {
+            inferred = owner_graph()->static_infer_manager()
+                    .infer_shape_fallible(input(0));
+        }
+        if (inferred) {
+            // the inferred shape takes precedence; warn if the hint disagrees
+            dest = *inferred;
+            if (!dest.eq_shape(m_shape)) {
+                mgb_log_warn(
+                        "given shape hint on var %s is different from inferred "
+                        "shape, hint %s vs inferred %s",
+                        cg::dump_var_info({input(0)}).c_str(),
+                        m_shape.to_string().c_str(),
+                        dest.to_string().c_str());
+            }
+        } else {
+            dest = m_shape;
+        }
+        return dest.ndim;
+    };
+    owner_graph()->static_infer_manager().register_shape_infer(
+            output(0), {m_is_const ?
+                    SourceType::CONSTANT : SourceType::MUTABLE,
+             {}, infer_shp});
+}
+
+ShapeHint::ShapeHint(VarNode* inp, TensorShape shape, bool is_const,
+        const OperatorNodeConfig& config)
+        : Super{inp->owner_graph(), config, "shape_hint", {inp}},
+          m_shape(shape), m_is_const(is_const) {
+    add_input({inp});
+    add_output(None);
+}
+
+SymbolVar ShapeHint::make(SymbolVar inp, TensorShape shape, bool is_const,
+        const OperatorNodeConfig& config) {
+    return inp.insert_single_output_opr<ShapeHint>(
+            inp.node(), shape, is_const, config);
+}
+
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(ShapeHint) {
+    // since the shape of output(0) could be inferred, there is no need to
+    // give a hint on out_grad(0)
+    return out_grad.at(0);
+}
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/utility.oprdecl b/src/opr/impl/utility.oprdecl
index a3b84243..897ebc3e 100644
--- a/src/opr/impl/utility.oprdecl
+++ b/src/opr/impl/utility.oprdecl
@@ -90,4 +90,15 @@ decl_opr(
     params='Empty'
 )
 
+decl_raw_opr(
+    'shape_hint',
+    desc='a special op providing a shape hint, only used in graph compilation',
+    inputs=[Doc('input', 'input var the shape hint is given for'),
+            Doc('shape', 'given hint shape', 'list of int'),
+            Doc('is_const', 'whether to treat the given shape as constant',
+                'bool', 'False')],
+    body=[
+        'output = _mgb._Opr.shape_hint(input, shape, is_const, config)'
+    ]
+)
+
 # vim: ft=python
diff --git a/src/opr/impl/utility.sereg.h b/src/opr/impl/utility.sereg.h
index 159e8bea..eb866117 100644
--- a/src/opr/impl/utility.sereg.h
+++ b/src/opr/impl/utility.sereg.h
@@ -153,6 +153,17 @@ namespace opr {
 #endif
 
     MGB_SEREG_OPR(PersistentOutputStorage, 1);
+
+    cg::OperatorNodeBase* opr_shallow_copy_shape_hint(
+            const serialization::OprShallowCopyContext& ctx,
+            const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
+            const OperatorNodeConfig& config) {
+        auto&& opr = opr_.cast_final_safe<ShapeHint>();
+        mgb_assert(inputs.size() == 1);
+        return ShapeHint::make(inputs[0], opr.shape(), opr.is_const(), config)
+                .node()->owner_opr();
+    }
+    MGB_REG_OPR_SHALLOW_COPY(ShapeHint, opr_shallow_copy_shape_hint);
 } // namespace opr
 } // namespace mgb
diff --git a/src/opr/include/megbrain/opr/utility.h b/src/opr/include/megbrain/opr/utility.h
index 68130911..37783dec 100644
--- a/src/opr/include/megbrain/opr/utility.h
+++ b/src/opr/include/megbrain/opr/utility.h
@@ -512,6 +512,27 @@ public:
            const OperatorNodeConfig& config = {});
 };
 
+/*!
+ * \brief a special op providing a shape hint, only used in graph
+ *      compilation (gopt)
+ */
+MGB_DEFINE_OPR_CLASS(ShapeHint, cg::SingleCNOperatorNodeBase) // {
+    TensorShape m_shape;
+    bool m_is_const;
+
+    void scn_do_execute() override;
+    void init_output_static_infer_desc() override;
+
+public:
+    ShapeHint(VarNode* inp, const TensorShape shape, bool is_const,
+              const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar inp, const TensorShape shape,
+                          bool is_const = false,
+                          const OperatorNodeConfig& config = {});
+
+    TensorShape shape() const { return m_shape; }
+    bool is_const() const { return m_is_const; }
+};
+
 } // namespace opr
 } // namespace mgb
diff --git a/src/opr/test/utility.cpp b/src/opr/test/utility.cpp
index 3921c6a1..357c3b0e 100644
--- a/src/opr/test/utility.cpp
+++ b/src/opr/test/utility.cpp
@@ -12,6 +12,7 @@
 #include "megbrain/opr/utility.h"
 #include "megbrain/gopt/framework.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/test/helper.h"
 
 using namespace mgb;
@@ -467,4 +468,64 @@ TEST(TestOprUtility, RequireInputDynamicStorage) {
     ASSERT_LT(nr_opr(func), nr0);
 }
 
+TEST(TestOprUtility, ShapeHint) {
+    HostTensorGenerator<> gen;
+    HostTensorGenerator<dtype::Int32> gen_int;
+    constexpr size_t length = 233;
+
+    {  // basic
+        for (bool dynamic : {false, true}) {
+            auto host_x = gen_int({length});
+            auto graph = ComputingGraph::make();
+            SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
+                      x_shape_hint, y;
+            if (dynamic) {
+                x_shape_hint = opr::ShapeHint::make(
+                        opr::MarkDynamicVar::make(x), TensorShape{length * 2});
+            } else {
+                x_shape_hint = opr::ShapeHint::make(x, TensorShape{length * 2});
+            }
+            y = x_shape_hint * 2 + 1;
+            if (dynamic) {
+                // shape can not be inferred, so the hint wins
+                ASSERT_TRUE(y.shape().eq_shape({length * 2}));
+            } else {
+                // inferred shape takes precedence over the hint
+                ASSERT_TRUE(y.shape().eq_shape({length}));
+            }
+            HostTensorND host_y;
+            auto func = graph->compile({make_callback_copy(y, host_y)});
+            func->execute();
+            ASSERT_TRUE(host_y.shape().eq_shape({length}));
+            for (size_t i = 0; i < length; ++i) {
+                ASSERT_EQ(host_x->ptr<int>()[i] * 2 + 1, host_y.ptr<int>()[i]);
+            }
+        }
+    }
+
+    {  // shallow copy
+        auto graph = ComputingGraph::make();
+        auto host_x = gen({length});
+        SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
+                  y = opr::ShapeHint::make(x, TensorShape{length * 2}),
+                  x_unknown = opr::MarkDynamicVar::make(x),
+                  y_copy = serialization::copy_opr_shallow(
+                          *y.node()->owner_opr(),
+                          {x_unknown.node()})->output(0);
+        ASSERT_TRUE(y.shape().eq_shape({length}));
+        ASSERT_TRUE(y_copy.shape().eq_shape({length * 2}));
+    }
+
+    {  // grad
+        auto host_x = gen({1}), host_y = gen({1});
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Host2DeviceCopy::make(*graph, host_y),
+             x_shape_hint = opr::ShapeHint::make(
+                     opr::MarkDynamicVar::make(x), TensorShape{1}),
+             y_shape_hint = opr::ShapeHint::make(y, TensorShape{1}),
+             t = x_shape_hint * y_shape_hint;
+        HostTensorND host_gx, host_gy;
+        auto func = graph->compile(
+                {make_callback_copy(cg::grad(t, x), host_gx),
+                 make_callback_copy(cg::grad(t, y), host_gy)});
+        func->execute();
+        ASSERT_TRUE(host_gx.shape().is_scalar());
+        ASSERT_TRUE(host_gy.shape().is_scalar());
+        ASSERT_FLOAT_EQ(*host_x->ptr<float>(), *host_gy.ptr<float>());
+        ASSERT_FLOAT_EQ(*host_y->ptr<float>(), *host_gx.ptr<float>());
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
-- 
GitLab
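
Usage sketch (reviewer illustration, not part of the patch): how a caller is
expected to attach a hint, using only APIs that appear in this patch and its
tests; the tensor values and the shape {233} are illustrative.

    // x_dyn's shape is only known at runtime, so static inference fails on it;
    // ShapeHint lets gopt optimize as if the shape were {233}
    HostTensorGenerator<dtype::Int32> gen_int;
    auto host_x = gen_int({233});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto x_dyn = opr::MarkDynamicVar::make(x);
    // pass is_const=true only when the shape never changes across executions
    auto hinted = opr::ShapeHint::make(x_dyn, TensorShape{233});
    auto y = hinted * 2 + 1;  // y now carries static shape {233}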
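
For callers that drive gopt by hand instead of going through compile(), the
hint can be stripped the same way compile_prepare() does in cg_impl.cpp; a
minimal sketch, assuming dest_vars is the SymbolVarArray of graph outputs:

    gopt::GraphOptimizer opt;
    opt.add_pass<gopt::RemoveShapeHintPass>();  // replace each ShapeHint output with its input
    opt.apply_inplace(dest_vars);               // rewrite dest_vars to hint-free vars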