提交 93ceb80a 编写于 作者: M Megvii Engine Team

refactor(imperative): fix broadcast,reshape,reduce

GitOrigin-RevId: ee3dc1487ddfab276e0a8e39801412f19efeaa96
上级 d919aaeb
......@@ -81,10 +81,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto& input = inputs[0];
TensorShape target_shape;
cg::copy_tensor_value_to_shape(
target_shape, inputs[1]->get_value().proxy_to_default_cpu());
TensorPtr output = Tensor::make(
TensorLayout(target_shape, input->dtype()), input->comp_node());
if (output->layout().is_empty()) {
return {output};
}
if (input->shape().eq_shape(output->shape())) {
mgb_assert(input->layout().eq_layout(output->layout()));
output->dev_tensor().copy_from_fixlayout(input->dev_tensor());
} else {
TensorLayout input_layout = input->layout().broadcast(output->shape());
output->dev_tensor().copy_from_fixlayout(
input->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout)));
}
return {output};
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace broadcast
......@@ -147,9 +170,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {Tensor::make(src->blob(), 0, tlayout)};
}
OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace reshape
......
......@@ -50,9 +50,18 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
return false;
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
if (memory_forward_success(def, inputs)) {
return {Tensor::make(inputs[0]->blob(), 0, inputs[0]->layout())};
}
return proxy_graph_detail::apply_on_physical_tensor(def, inputs);
}
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace reduce
} // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册