提交 fe1680b3 编写于 作者: M Megvii Engine Team

fix(imperative/ops): improve infer_output_attrs for broadcast

GitOrigin-RevId: 6b7ed5576947a7614b7ad00bbfcbc1d5e520a33a
上级 b1806b84
......@@ -58,10 +58,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
auto&& src = inputs[0];
auto&& tshp = inputs[1];
TensorLayout out_layout = src.layout;
TensorShape out_shape;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{{out_layout, src.comp_node}}, false};
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
......@@ -69,17 +69,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
out_shape[i] = ptr[i];
}
mgb_assert(valid_broadcast(src.layout, out_layout),
mgb_assert(valid_broadcast(src.layout, out_shape),
"the input shape %s can not be broadcasted to target shape %s",
src.layout.TensorShape::to_string().c_str(),
out_layout.TensorShape::to_string().c_str());
src.layout.to_string().c_str(),
out_shape.to_string().c_str());
return {{{out_layout, src.comp_node}}, true};
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
......@@ -108,10 +108,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
auto&& src = inputs[0];
auto&& tshp = inputs[1];
TensorLayout out_layout = src.layout;
TensorShape out_shape;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{{out_layout, src.comp_node}}, false};
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
......@@ -119,31 +119,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
out_shape[i] = ptr[i];
}
if (src.layout.ndim == 0) {
return {{{out_layout, src.comp_node}}, false};
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(out_layout.shape[op.axis] == -1);
out_layout.shape[op.axis] = 1;
mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
mgb_assert(out_shape[op.axis] == -1);
out_shape[op.axis] = 1;
mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
"can not reshape from %s to %s",
src.layout.to_string().c_str(),
out_layout.to_string().c_str());
out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
out_shape.to_string().c_str());
out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
} else {
mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
"can not reshape from %s to %s",
src.layout.to_string().c_str(),
out_layout.to_string().c_str());
out_shape.to_string().c_str());
}
return {{{out_layout, src.comp_node}}, true};
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
OP_TRAIT_REG(Reshape, Reshape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册