diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index 5bc640af25e5a02c9093d85bab018fa7e9e416ad..b247addcc8f6747dd51b77e3e24ff5e164480243 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -81,10 +81,33 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; } +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs) { + auto& input = inputs[0]; + TensorShape target_shape; + cg::copy_tensor_value_to_shape( + target_shape, inputs[1]->get_value().proxy_to_default_cpu()); + TensorPtr output = Tensor::make( + TensorLayout(target_shape, input->dtype()), input->comp_node()); + if (output->layout().is_empty()) { + return {output}; + } + if (input->shape().eq_shape(output->shape())) { + mgb_assert(input->layout().eq_layout(output->layout())); + output->dev_tensor().copy_from_fixlayout(input->dev_tensor()); + } else { + TensorLayout input_layout = input->layout().broadcast(output->shape()); + output->dev_tensor().copy_from_fixlayout( + input->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout))); + } + return {output}; +} + OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) .fallback(); } // namespace broadcast @@ -147,9 +170,31 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; } +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs) { + auto&& op_def = def.cast_final_safe(); + size_t nr_inp = inputs.size(); + mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); + auto&& src = inputs[0]; + auto&& tshp_nd = inputs[1]; + auto slayout = src->layout(); + + TensorShape tshp; + cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); + if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { + mgb_assert(tshp[op_def.axis] == -1); + tshp[op_def.axis] = 1; + tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); + } + TensorLayout tlayout = slayout.reshape(tshp); + // memory forward + return {Tensor::make(src->blob(), 0, tlayout)}; +} + OP_TRAIT_REG(Reshape, Reshape) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) .fallback(); } // namespace reshape diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp index e46f292705943408fc90c4bf6a6043603643e2fa..030709ee2379d8087bdda3f9b58387fe15c6af54 100644 --- a/imperative/src/impl/ops/reduce.cpp +++ b/imperative/src/impl/ops/reduce.cpp @@ -50,9 +50,18 @@ bool memory_forward_success(const OpDef& def, SmallVector inputs) { return false; } +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs) { + if (memory_forward_success(def, inputs)) { + return {Tensor::make(inputs[0]->blob(), 0, inputs[0]->layout())}; + } + return proxy_graph_detail::apply_on_physical_tensor(def, inputs); +} + OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) .fallback(); } // namespace reduce } // namespace