From 4b27e861f4701d0caf084a2f341917ea724293e8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 13 Apr 2022 12:18:51 +0800 Subject: [PATCH] fix(ops): implement from_op_node for reshape GitOrigin-RevId: 4c994385041d94f34ab68029e5ac5f09f786d5a7 --- imperative/src/impl/ops/broadcast.cpp | 6 ++++++ imperative/src/impl/ops/tensor_manip.cpp | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index ba9100fb2..db3a1506e 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -125,6 +125,11 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) namespace reshape { +auto make_from_op_node(const cg::OperatorNodeBase* node) { + auto& opr = node->cast_final_safe(); + return Reshape::make(opr.param(), std::vector()); +} + auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); @@ -261,6 +266,7 @@ OP_TRAIT_REG(Reshape, Reshape) .infer_output_attrs_fallible(infer_output_attrs_fallible) .apply_on_physical_tensor(apply_on_physical_tensor) .get_input_layout_constraint(get_input_layout_constraint) + .make_from_op_node(make_from_op_node) .fallback(); } // namespace reshape diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 403356389..00eca1e76 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -87,7 +87,7 @@ HostTensorND get_var_shape_host_tensor( const OpDef& def, const SmallVector& inputs) { SmallVector input_tensornds; for (auto&& inp : inputs) { - input_tensornds.push_back(inp->dev_tensor()); + input_tensornds.push_back(inp->dev_tensor(false)); } SmallVector output_tensornds = { {CompNode::default_cpu(), dtype::Int32()}}; @@ -100,7 +100,7 @@ HostTensorND get_var_shape_host_tensor( SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { - return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))}; + return {Tensor::make(get_var_shape_host_tensor(def, inputs))}; } std::tuple, bool> infer_output_attrs_fallible( -- GitLab