perf(opdef/reshape): specialize Reshape

GitOrigin-RevId: 26d0e151ca89058782c554026907fd1ad3ec7340
上级 77309609
无相关合并请求
......@@ -152,9 +152,43 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
}
void execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(inputs[0]->offset() == outputs[0]->offset());
mgb_assert(inputs[0]->blob() == outputs[0]->blob());
}
OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // reshape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部