提交 4b27e861 编写于 作者: M Megvii Engine Team

fix(ops): implement from_op_node for reshape

GitOrigin-RevId: 4c994385041d94f34ab68029e5ac5f09f786d5a7
上级 4fb3d886
......@@ -125,6 +125,11 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
namespace reshape {
auto make_from_op_node(const cg::OperatorNodeBase* node) {
auto& opr = node->cast_final_safe<opr::Reshape>();
return Reshape::make(opr.param(), std::vector<int32_t>());
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Reshape&>(def);
mgb_assert(inputs.size() == 2);
......@@ -261,6 +266,7 @@ OP_TRAIT_REG(Reshape, Reshape)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.make_from_op_node(make_from_op_node)
.fallback();
} // namespace reshape
......
......@@ -87,7 +87,7 @@ HostTensorND get_var_shape_host_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<DeviceTensorND> input_tensornds;
for (auto&& inp : inputs) {
input_tensornds.push_back(inp->dev_tensor());
input_tensornds.push_back(inp->dev_tensor(false));
}
SmallVector<DeviceTensorND> output_tensornds = {
{CompNode::default_cpu(), dtype::Int32()}};
......@@ -100,7 +100,7 @@ HostTensorND get_var_shape_host_tensor(
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册