提交 4b27e861 编写于 作者: M Megvii Engine Team

fix(ops): implement from_op_node for reshape

GitOrigin-RevId: 4c994385041d94f34ab68029e5ac5f09f786d5a7
上级 4fb3d886
...@@ -125,6 +125,11 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) ...@@ -125,6 +125,11 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
namespace reshape { namespace reshape {
auto make_from_op_node(const cg::OperatorNodeBase* node) {
auto& opr = node->cast_final_safe<opr::Reshape>();
return Reshape::make(opr.param(), std::vector<int32_t>());
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Reshape&>(def); auto&& op = static_cast<const Reshape&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
...@@ -261,6 +266,7 @@ OP_TRAIT_REG(Reshape, Reshape) ...@@ -261,6 +266,7 @@ OP_TRAIT_REG(Reshape, Reshape)
.infer_output_attrs_fallible(infer_output_attrs_fallible) .infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor) .apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint) .get_input_layout_constraint(get_input_layout_constraint)
.make_from_op_node(make_from_op_node)
.fallback(); .fallback();
} // namespace reshape } // namespace reshape
......
...@@ -87,7 +87,7 @@ HostTensorND get_var_shape_host_tensor( ...@@ -87,7 +87,7 @@ HostTensorND get_var_shape_host_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) { const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<DeviceTensorND> input_tensornds; SmallVector<DeviceTensorND> input_tensornds;
for (auto&& inp : inputs) { for (auto&& inp : inputs) {
input_tensornds.push_back(inp->dev_tensor()); input_tensornds.push_back(inp->dev_tensor(false));
} }
SmallVector<DeviceTensorND> output_tensornds = { SmallVector<DeviceTensorND> output_tensornds = {
{CompNode::default_cpu(), dtype::Int32()}}; {CompNode::default_cpu(), dtype::Int32()}};
...@@ -100,7 +100,7 @@ HostTensorND get_var_shape_host_tensor( ...@@ -100,7 +100,7 @@ HostTensorND get_var_shape_host_tensor(
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs, const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))}; return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
} }
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册