diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 21671e06e8f5b84ce9fb92ced6e5f7b783f15285..2895688b53ecda1fb8d03a06a4cecdff3748bd1c 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -152,9 +152,43 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    auto&& op_def = def.cast_final_safe<Reshape>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto&& tshp_nd = inputs[1];
+    auto slayout = src->layout();
+
+    TensorShape tshp;
+    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
+    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
+        mgb_assert(tshp[op_def.axis] == -1);
+        tshp[op_def.axis] = 1;
+        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
+    }
+    TensorLayout tlayout = slayout.reshape(tshp);
+    // memory forward
+    return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
+}
+
+void execute(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs,
+        SmallVector<TensorPtr> outputs,
+        SmallVector<TensorPtr> workspace) {
+    mgb_assert(inputs[0]->offset() == outputs[0]->offset());
+    mgb_assert(inputs[0]->blob() == outputs[0]->blob());
+}
+
 OP_TRAIT_REG(Reshape, Reshape)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .infer_output_mem_desc(infer_output_mem_desc)
+        .execute(execute)
         .fallback();
 }  // reshape
 