Commit 8fe865d8 authored by Megvii Engine Team

feat(imperative/ops): add infer_output_attrs for Reshape

GitOrigin-RevId: 9150d7f84d1f0a4e50f5160213c660fdca904224
Parent 267d6127
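
The new `infer_output_attrs_fallible` for Reshape fills the output layout from the target-shape tensor when its value is statically known, and resolves a single wildcard dimension (the `-1` entry recorded in `op.axis`) so that the element count matches the source. The following standalone C++ sketch shows that arithmetic; it is illustrative only, with hypothetical names, and is not part of the commit or the MegEngine API:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative sketch, not MegEngine code: resolve a reshape target that may
// carry one -1 entry at `axis` (axis < 0 means "no wildcard dimension").
std::vector<size_t> infer_reshape(const std::vector<size_t>& src,
                                  const std::vector<int32_t>& target,
                                  int axis) {
    size_t src_elems = std::accumulate(src.begin(), src.end(), size_t{1},
                                       std::multiplies<size_t>());
    std::vector<size_t> out(target.size());
    for (size_t i = 0; i < target.size(); ++i)
        out[i] = static_cast<size_t>(target[i]);
    if (axis >= 0) {
        assert(target[axis] == -1);      // the wildcard slot to be inferred
        out[axis] = 1;
        size_t known = std::accumulate(out.begin(), out.end(), size_t{1},
                                       std::multiplies<size_t>());
        assert(src_elems % known == 0);  // otherwise the reshape is invalid
        out[axis] = src_elems / known;   // e.g. 24 elements / (6*1) -> 4
    } else {
        size_t total = std::accumulate(out.begin(), out.end(), size_t{1},
                                       std::multiplies<size_t>());
        assert(src_elems == total);      // element counts must agree exactly
    }
    return out;
}

int main() {
    // (2, 3, 4) reshaped to (6, -1): the wildcard becomes 24 / 6 = 4.
    auto out = infer_reshape({2, 3, 4}, {6, -1}, /*axis=*/1);
    assert(out.size() == 2 && out[0] == 6 && out[1] == 4);
    return 0;
}
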
@@ -17,7 +17,7 @@
 namespace mgb {
 namespace imperative {
-namespace {
+namespace broadcast {
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     node_->cast_final_safe<opr::Broadcast>();
@@ -39,7 +39,7 @@ bool valid_broadcast(const TensorShape& src_shape,
     if (src_ndim > tar_ndim) {
         return false;
     }
-    size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
+    size_t min_ndim = src_ndim;
     for (size_t i = 0; i < min_ndim; ++i) {
         if (src_shape[src_ndim - i - 1] != 1 &&
             src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
@@ -87,7 +87,70 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .fallback();
-} // anonymous namespace
+} // broadcast
+namespace reshape {
+
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& op = static_cast<const Reshape&>(def);
+    mgb_assert(inputs.size() == 2);
+    return opr::Reshape::make(inputs[0], inputs[1], op.param());
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = def.cast_final_safe<Reshape>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto&& tshp = inputs[1];
+
+    TensorLayout out_layout = src.layout;
+    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+        out_layout.ndim = 0;
+        return {{{out_layout, src.comp_node}}, false};
+    }
+    mgb_assert(
+        tshp.layout.ndim == 1,
+        "target shape of Reshape expects ndim=1; got ndim=%lu actually",
+        tshp.layout.ndim);
+
+    size_t target_ndim = tshp.layout.shape[0];
+    out_layout.ndim = target_ndim;
+    auto* ptr = tshp.value.ptr<dt_int32>();
+    for (size_t i = 0; i < target_ndim; ++i) {
+        out_layout.shape[i] = ptr[i];
+    }
+
+    if (src.layout.ndim == 0) {
+        return {{{out_layout, src.comp_node}}, false};
+    }
+
+    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
+        mgb_assert(out_layout.shape[op.axis] == -1);
+        out_layout.shape[op.axis] = 1;
+        mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
+            "can not reshape from %s to %s",
+            src.layout.to_string().c_str(),
+            out_layout.to_string().c_str());
+        out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
+    } else {
+        mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
+            "can not reshape from %s to %s",
+            src.layout.to_string().c_str(),
+            out_layout.to_string().c_str());
+    }
+    return {{{out_layout, src.comp_node}}, true};
+}
+
+OP_TRAIT_REG(Reshape, Reshape)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .fallback();
+} // reshape
 } // namespace imperative
 } // namespace mgb
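
Note the `std::tuple<SmallVector<LogicalTensorDesc>, bool>` return type: the `bool` reports whether the inferred descriptors are final. When the target-shape value is not yet available (`tshp.value.empty()`), the function returns an `ndim = 0` layout and `false`, and the caller must fall back to runtime shape inference. A hypothetical caller-side sketch of that contract follows; the names are illustrative, not the MegEngine API:

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Illustrative sketch of a fallible inference result: a best-effort shape
// plus a flag telling the caller whether it is fully determined.
struct InferResult {
    std::vector<size_t> shape;  // meaningful only when validated is true
    bool validated;
};

InferResult infer(const std::optional<std::vector<int32_t>>& target_value) {
    if (!target_value) {
        // Target shape not known at trace time: signal "not validated" so
        // the caller defers shape computation to execution.
        return {{}, false};
    }
    return {{target_value->begin(), target_value->end()}, true};
}

int main() {
    auto r = infer(std::nullopt);
    std::printf("validated=%d\n", r.validated);  // prints validated=0
    return 0;
}
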
...
@@ -548,19 +548,6 @@ OP_TRAIT_REG(Remap, Remap)
         .fallback();
 }} // remap
-namespace { namespace reshape {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const Reshape&>(def);
-    mgb_assert(inputs.size() == 2);
-    return opr::Reshape::make(inputs[0], inputs[1], op.param());
-}
-
-OP_TRAIT_REG(Reshape, Reshape)
-        .apply_on_var_node(apply_on_var_node)
-        .fallback();
-}} // reshape
 namespace {
 auto get_index(
         const VarNodeArray& inputs, size_t vidx,
...