From 0ef5183c65fa5246a125f501be4b89f4b55e2c81 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 11 Aug 2021 14:12:54 +0800 Subject: [PATCH] perf(opdef/reshape): specialize Reshape GitOrigin-RevId: 26d0e151ca89058782c554026907fd1ad3ec7340 --- imperative/src/impl/ops/broadcast.cpp | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index 21671e06e..2895688b5 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -152,9 +152,43 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; } +std::tuple, SmallVector> infer_output_mem_desc( + const OpDef& def, + const SmallVector& inputs, + const SmallVector& inputs_mems) { + auto&& op_def = def.cast_final_safe(); + size_t nr_inp = inputs.size(); + mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); + auto&& src = inputs[0]; + auto&& tshp_nd = inputs[1]; + auto slayout = src->layout(); + + TensorShape tshp; + cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); + if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { + mgb_assert(tshp[op_def.axis] == -1); + tshp[op_def.axis] = 1; + tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); + } + TensorLayout tlayout = slayout.reshape(tshp); + // memory forward + return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}}; +} + +void execute( + const OpDef& def, + SmallVector inputs, + SmallVector outputs, + SmallVector workspace) { + mgb_assert(inputs[0]->offset() == outputs[0]->offset()); + mgb_assert(inputs[0]->blob() == outputs[0]->blob()); +} + OP_TRAIT_REG(Reshape, Reshape) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) + .infer_output_mem_desc(infer_output_mem_desc) + .execute(execute) .fallback(); } // reshape -- GitLab