From 8182af6eb6546281c284ba032da1353d59b79bac Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 25 Apr 2022 11:20:35 +0800
Subject: [PATCH] fix(mgb): fix strategy of grad_op and opr_attr

GitOrigin-RevId: bb7ab8fa9dd838cf90841ff8a0222661c9eeff04
---
 .../megengine/core/tensor/array_method.py    |  9 ---
 imperative/python/megengine/functional/nn.py |  2 +
 imperative/python/src/tensor_utils.cpp       |  8 +--
 imperative/src/impl/ops/opr_attr.cpp         | 62 ++++++++++++++++++-
 .../megbrain/imperative/ops/opr_attr.h       |  5 ++
 imperative/src/test/backward_graph.cpp       | 45 ++++++++++++++
 src/opr/impl/blas.cpp                        | 24 ++++---
 src/opr/impl/dnn/pooling.cpp                 |  3 +-
 src/opr/include/megbrain/opr/blas.h          |  4 +-
 src/opr/include/megbrain/opr/dnn/pooling.h   |  4 +-
 .../opr/search_policy/algo_chooser_helper.h  |  2 +-
 11 files changed, 138 insertions(+), 30 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index b055f31f2..d072a3286 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -70,15 +70,6 @@ def _matmul(
     maxdim = dim1 if dim1 > dim2 else dim2
 
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    Strategy = builtin.ops.MatrixMul.Strategy
-    strategy = Strategy(0)
-    if _config._benchmark_kernel:
-        strategy |= Strategy.PROFILE
-    else:
-        strategy |= Strategy.HEURISTIC
-    if _config._deterministic_kernel:
-        strategy |= Strategy.REPRODUCIBLE
-
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 75f84a0de..2d5a49762 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -621,6 +621,7 @@ def max_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode="max",
+        strategy=get_execution_strategy(),
         format=conv_format,
     )
     (output,) = apply(op, inp)
@@ -665,6 +666,7 @@ def avg_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode=mode,
+        strategy=get_execution_strategy(),
         format=conv_format,
     )
     (output,) = apply(op, inp)
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 1b6fd9443..a455f8128 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -1493,7 +1493,7 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
 py::object _matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle profile, py::handle determistic) {
+        py::handle profile, py::handle deterministic) {
     ::megdnn::param::MatrixMul::ComputeMode mode =
             ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
     if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
@@ -1506,7 +1506,7 @@ py::object _matmul_cpp(
     } else {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
-    if (determistic.cast<bool>()) {
+    if (deterministic.cast<bool>()) {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
     }
     std::shared_ptr<OpDef> op = MatrixMul::make(
@@ -1523,7 +1523,7 @@ py::object _matmul_cpp(
 py::object _batched_matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle profile, py::handle determistic) {
+        py::handle profile, py::handle deterministic) {
     ::megdnn::param::MatrixMul::ComputeMode mode =
             ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
     if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
@@ -1536,7 +1536,7 @@ py::object _batched_matmul_cpp(
     } else {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
-    if (determistic.cast<bool>()) {
+    if (deterministic.cast<bool>()) {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
     }
     std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
diff --git a/imperative/src/impl/ops/opr_attr.cpp b/imperative/src/impl/ops/opr_attr.cpp
index ef5dbc73b..dfa09150b 100644
--- a/imperative/src/impl/ops/opr_attr.cpp
+++ b/imperative/src/impl/ops/opr_attr.cpp
@@ -10,6 +10,10 @@
  */
 
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/opr/blas.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/rdnn/profiler.h"
 #include "megbrain/serialization/opr_load_dump.h"
 
 #include "../op_trait.h"
@@ -65,6 +69,42 @@ public:
     const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
 };
 
+#define cb(FASTRUN_OPR)                                                           \
+    megdnn::param::ExecutionPolicy get_strategy_##FASTRUN_OPR(                    \
+            cg::OperatorNodeBase* opr) {                                          \
+        auto policy =                                                             \
+                opr->cast_final<opr::FASTRUN_OPR>().execution_policy_transient(); \
+        return policy;                                                            \
+    }                                                                             \
+    void set_strategy_##FASTRUN_OPR(                                              \
+            cg::OperatorNodeBase* opr, megdnn::param::ExecutionPolicy policy) {   \
+        auto&& p = opr->cast_final<opr::FASTRUN_OPR>();                           \
+        p.set_execution_policy(policy);                                           \
+    }
+
+DNN_FOREACH_FASTRUN_OPR(cb)
+#undef cb
+
+typedef thin_function<megdnn::param::ExecutionPolicy(cg::OperatorNodeBase*)> get_func;
+typedef thin_function<void(cg::OperatorNodeBase*, megdnn::param::ExecutionPolicy)>
+        set_func;
+
+static const mgb::thin_hash_table::ThinHashMap<
+        mgb::Typeinfo*, std::pair<get_func, set_func>>&
+get_type2policy() {
+    static mgb::thin_hash_table::ThinHashMap<
+            mgb::Typeinfo*, std::pair<get_func, set_func>>
+            sl_type2policy;
+    static std::once_flag flag;
+    std::call_once(flag, [&]() {
+#define cb(FASTRUN_OPR)                            \
+    sl_type2policy[opr::FASTRUN_OPR::typeinfo()] = \
+            std::make_pair(get_strategy_##FASTRUN_OPR, set_strategy_##FASTRUN_OPR);
+        DNN_FOREACH_FASTRUN_OPR(cb)
+    });
+    return std::as_const(sl_type2policy);
+}
+
 VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
     auto config = attr.config;
@@ -73,7 +113,12 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, config).usable_output();
+    auto opr_with_accessor = registry->loader(ctx, inputs, config);
+    auto&& opr = opr_with_accessor.opr();
+    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
+        get_type2policy().at(opr->dyn_typeinfo()).second(opr, attr.policy);
+    }
+    return opr_with_accessor.usable_output();
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
@@ -84,7 +129,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
             registry->dumper, "operator %s cannot be serialized",
             opr->dyn_typeinfo()->name);
     registry->dumper(ctx, *opr);
-    return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config());
+    megdnn::param::ExecutionPolicy policy;
+    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
+        policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
+    }
+    return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
 }
 
 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
@@ -108,6 +157,8 @@ OP_TRAIT_REG(OprAttr, OprAttr)
 bool OprAttr::is_same_st(const Hashable& rhs_) const {
     auto&& rhs = static_cast<const OprAttr&>(rhs_);
     return type == rhs.type && param == rhs.param &&
+           policy.strategy == rhs.policy.strategy &&
+           policy.workspace_limit == rhs.policy.workspace_limit &&
            config.comp_node() == rhs.config.comp_node() &&
            config.output_dtype() == rhs.config.output_dtype();
 }
@@ -115,7 +166,12 @@ bool OprAttr::is_same_st(const Hashable& rhs_) const {
 size_t OprAttr::hash() const {
     return hash_pair_combine(
             hash_pair_combine(
-                    mgb::hash(type), mgb::hash(static_cast<std::vector<char>>(param))),
+                    hash_pair_combine(
+                            mgb::hash(type),
+                            mgb::hash(static_cast<std::vector<char>>(param))),
+                    hash_pair_combine(
+                            static_cast<size_t>(policy.strategy),
+                            policy.workspace_limit)),
             config.hash());
 }
 
diff --git a/imperative/src/include/megbrain/imperative/ops/opr_attr.h b/imperative/src/include/megbrain/imperative/ops/opr_attr.h
index bea939825..55b129414 100644
--- a/imperative/src/include/megbrain/imperative/ops/opr_attr.h
+++ b/imperative/src/include/megbrain/imperative/ops/opr_attr.h
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "megbrain/imperative/op_def.h"
+#include "megbrain/opr/param_defs.h"
 
 namespace mgb {
 namespace imperative {
@@ -38,12 +39,16 @@ public:
 
     Type type;
     Param param;
+    megdnn::param::ExecutionPolicy policy;
     cg::OperatorNodeConfig config;
 
     OprAttr() = default;
     OprAttr(const Type& t) : type(t) {}
     OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
             : type(t), param(p), config(c) {}
+    OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
+            const cg::OperatorNodeConfig& c)
+            : type(t), param(p), policy(ps), config(c) {}
 
     std::string repr() const;
 
diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp
index 8b6b292fe..8f854ee45 100644
--- a/imperative/src/test/backward_graph.cpp
+++ b/imperative/src/test/backward_graph.cpp
@@ -157,6 +157,51 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
 }
 
+TEST(TestImperative, ProfileBackward) {
+    auto cn = CompNode::load("xpux");
+    using Policy = megdnn::param::ExecutionPolicy;
+    using S = Policy::Strategy;
+    Policy policy;
+    policy.strategy = S::PROFILE;
+    {
+        megdnn::param::Convolution param;
+        auto op = std::shared_ptr<OpDef>(Convolution::make(param, policy));
+        LogicalTensorDesc inp_desc = {
+                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
+        LogicalTensorDesc weight_desc = {
+                TensorLayout({16, 3, 5, 5}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(
+                *op, {inp_desc, weight_desc}, {true, false}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = ConvolutionBackwardDataV2
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+    {
+        megdnn::param::Pooling param;
+        auto op = std::shared_ptr<OpDef>(Pooling::make(param, policy));
+        LogicalTensorDesc inp_desc = {
+                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(*op, {inp_desc}, {true}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = PoolingBackwardV1
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+    {
+        megdnn::param::MatrixMul param;
+        auto op = std::shared_ptr<OpDef>(MatrixMul::make(param, policy, 2, 2));
+        LogicalTensorDesc inp1_desc = {TensorLayout({12, 16}, dtype::Float32()), cn};
+        LogicalTensorDesc inp2_desc = {TensorLayout({16, 20}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(
+                *op, {inp1_desc, inp2_desc}, {true, false}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = MatrixMulV2
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+}
+
 TEST(TestImperative, BackwardGraphIdentity) {
     HostTensorGenerator<> gen;
     auto host_a = gen({42}), host_dc = gen({42});
diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp
index b641d4429..1ecde974b 100644
--- a/src/opr/impl/blas.cpp
+++ b/src/opr/impl/blas.cpp
@@ -185,17 +185,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
     if (wrt_idx == 0) {
         // A * B = C, A' = C' * Bt
         if (opr.param().transposeA) {
-            grad = MatrixMul::make(i1, og, {opr.param().transposeB, true});
+            grad = MatrixMul::make(
+                    i1, og, {opr.param().transposeB, true}, opr.execution_policy());
         } else {
-            grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB});
+            grad = MatrixMul::make(
+                    og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
         }
     } else {
         mgb_assert(wrt_idx == 1);
         // A * B = C, B' = At * C'
         if (opr.param().transposeB) {
-            grad = MatrixMul::make(og, i0, {true, opr.param().transposeA});
+            grad = MatrixMul::make(
+                    og, i0, {true, opr.param().transposeA}, opr.execution_policy());
         } else {
-            grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false});
+            grad = MatrixMul::make(
+                    i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
         }
     }
     return grad.node();
@@ -358,17 +362,21 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     if (wrt_idx == 0) {
         // A * B = C, A' = C' * Bt
         if (opr.param().transposeA) {
-            grad = BatchedMatrixMul::make(i1, og, {opr.param().transposeB, true});
+            grad = BatchedMatrixMul::make(
+                    i1, og, {opr.param().transposeB, true}, opr.execution_policy());
        } else {
-            grad = BatchedMatrixMul::make(og, i1, {false, !opr.param().transposeB});
+            grad = BatchedMatrixMul::make(
+                    og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
         }
     } else {
         mgb_assert(wrt_idx == 1);
         // A * B = C, B' = At * C'
         if (opr.param().transposeB) {
-            grad = BatchedMatrixMul::make(og, i0, {true, opr.param().transposeA});
+            grad = BatchedMatrixMul::make(
+                    og, i0, {true, opr.param().transposeA}, opr.execution_policy());
         } else {
-            grad = BatchedMatrixMul::make(i0, og, {!opr.param().transposeA, false});
+            grad = BatchedMatrixMul::make(
+                    i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
         }
     }
     return grad.node();
diff --git a/src/opr/impl/dnn/pooling.cpp b/src/opr/impl/dnn/pooling.cpp
index 1262f1a94..bec83feab 100644
--- a/src/opr/impl/dnn/pooling.cpp
+++ b/src/opr/impl/dnn/pooling.cpp
@@ -59,7 +59,8 @@ size_t PoolingForward::get_workspace_size_bytes(
 MGB_IMPL_OPR_GRAD(PoolingForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = PoolingBackward::make(
-            opr.input(0), opr.output(0), out_grad[0], opr.param());
+            opr.input(0), opr.output(0), out_grad[0], opr.param(),
+            opr.execution_policy());
     return grad.node();
 }
 #endif
diff --git a/src/opr/include/megbrain/opr/blas.h b/src/opr/include/megbrain/opr/blas.h
index 2d89d2e2b..4badde15d 100644
--- a/src/opr/include/megbrain/opr/blas.h
+++ b/src/opr/include/megbrain/opr/blas.h
@@ -26,7 +26,7 @@ namespace opr {
 /*!
  * \brief matrix_mul(trans0(opr0), trans1(opr1))
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -57,7 +57,7 @@ private:
 /*!
  * \brief batched matrix multiplication on 3D inputs
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         BatchedMatrixMul, intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
         public mixin::AlgoChooserHelper) // {
 public:
diff --git a/src/opr/include/megbrain/opr/dnn/pooling.h b/src/opr/include/megbrain/opr/dnn/pooling.h
index 3b4efddef..a5f518e62 100644
--- a/src/opr/include/megbrain/opr/dnn/pooling.h
+++ b/src/opr/include/megbrain/opr/dnn/pooling.h
@@ -18,7 +18,7 @@
 namespace mgb {
 namespace opr {
 
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -37,7 +37,7 @@ public:
 };
 using Pooling = PoolingForward;
 
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>,
         public mixin::AlgoChooserHelper) // {
 public:
diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h
index 38d7485a3..40344d8d1 100644
--- a/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h
+++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h
@@ -51,7 +51,7 @@ public:
      * Exception would be thrown if execution_policy() has been accessed,
      * since it would influence cache and many other decisions.
      */
-    void set_execution_policy(const ExecutionPolicy& policy);
+    MGE_WIN_DECLSPEC_FUC void set_execution_policy(const ExecutionPolicy& policy);
 
     /*!
     * \brief register a hook to implement custom algo chooser
--
GitLab