From e1c7b22ff000d0813d02a7f8ff2493552bf3cb40 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 11 Aug 2021 20:16:47 +0800 Subject: [PATCH] perf(ops): enable memory forward for reduce in special cases GitOrigin-RevId: dd6e1664c50aabadc40bab6d01fb7fb31720bd8b --- imperative/src/impl/ops/reduce.cpp | 40 ++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp index b9aa65203..461aa45c4 100644 --- a/imperative/src/impl/ops/reduce.cpp +++ b/imperative/src/impl/ops/reduce.cpp @@ -11,6 +11,7 @@ #include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/basic_arith.h" +#include "megbrain/imperative/proxy_graph_detail.h" #include "../op_trait.h" #include "../dnn_op_helper.h" @@ -35,9 +36,48 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { return Reduce::make(node->param()); } +bool memory_forward_success( + const OpDef& def, + SmallVector inputs) { + auto&& reduce = static_cast(def); + if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) { + auto shape_tensor = inputs[1]->get_value(); + TensorShape shape; + cg::copy_tensor_value_to_shape(shape, shape_tensor.proxy_to_default_cpu()); + if (shape.eq_shape(inputs[0]->shape())) { + return true; + } + } + return false; +} + +std::tuple, SmallVector> infer_output_mem_desc( + const OpDef& def, + const SmallVector& inputs_tensors, + const SmallVector& inputs_mems) { + if (memory_forward_success(def, inputs_tensors)) { + auto& src_desc = inputs_mems[0]; + return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}}, {}}; + } + return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems); +} + + +void execute(const OpDef& def, + SmallVector inputs, + SmallVector outputs, + SmallVector workspace) { + if (memory_forward_success(def, inputs)) { + return; + } + return proxy_graph_detail::execute(def, inputs, outputs, workspace); +} + OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) + .infer_output_mem_desc(infer_output_mem_desc) + .execute(execute) .fallback(); } // namespace reduce } // namespace -- GitLab