提交 e1c7b22f 编写于 作者: M Megvii Engine Team

perf(ops): enable memory forward for reduce in special cases

GitOrigin-RevId: dd6e1664c50aabadc40bab6d01fb7fb31720bd8b
上级 cd60d268
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "../op_trait.h" #include "../op_trait.h"
#include "../dnn_op_helper.h" #include "../dnn_op_helper.h"
...@@ -35,9 +36,48 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { ...@@ -35,9 +36,48 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Reduce::make(node->param()); return Reduce::make(node->param());
} }
bool memory_forward_success(
const OpDef& def,
SmallVector<TensorPtr> inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
auto shape_tensor = inputs[1]->get_value();
TensorShape shape;
cg::copy_tensor_value_to_shape(shape, shape_tensor.proxy_to_default_cpu());
if (shape.eq_shape(inputs[0]->shape())) {
return true;
}
}
return false;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
if (memory_forward_success(def, inputs_tensors)) {
auto& src_desc = inputs_mems[0];
return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}}, {}};
}
return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}
void execute(const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (memory_forward_success(def, inputs)) {
return;
}
return proxy_graph_detail::execute(def, inputs, outputs, workspace);
}
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node) .make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback(); .fallback();
} // namespace reduce } // namespace reduce
} // namespace } // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册