Commit fe1680b3 authored by Megvii Engine Team

fix(imperative/ops): improve infer_output_attrs for broadcast

GitOrigin-RevId: 6b7ed5576947a7614b7ad00bbfcbc1d5e520a33a
Parent b1806b84
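Editor's note on the change, hedged: in both functions touched below, the inferred output descriptor was previously seeded with `TensorLayout out_layout = src.layout;`, so the source's strides (laid out for the old shape) could survive into the result even after the shape fields were overwritten. The patch instead fills a plain `TensorShape` and constructs `TensorLayout(out_shape, src.layout.dtype)`, which describes a contiguous tensor of the new shape. A minimal standalone sketch of the difference, using an illustrative `Layout` struct rather than MegBrain's real types:

#include <cstdio>
#include <vector>

// Illustrative stand-in for a strided tensor layout; MegBrain's real
// TensorLayout (from megdnn) is richer, this only mirrors the idea.
struct Layout {
    std::vector<size_t> shape, stride;
    // Building from a shape alone yields contiguous strides, which is
    // what TensorLayout(out_shape, dtype) does in the patched code.
    explicit Layout(std::vector<size_t> s) : shape(std::move(s)) {
        stride.assign(shape.size(), 1);
        for (size_t i = shape.size(); i-- > 1;)
            stride[i - 1] = stride[i] * shape[i];
    }
};

int main() {
    Layout src({2, 3});         // contiguous: strides (3, 1)

    Layout out_old = src;       // old approach: copy the layout...
    out_old.shape = {2, 3, 4};  // ...then patch the shape only;
                                // strides stay (3, 1): stale

    Layout out_new({2, 3, 4});  // new approach: fresh layout,
                                // strides (12, 4, 1)

    std::printf("stale stride count: %zu, fresh stride count: %zu\n",
                out_old.stride.size(), out_new.stride.size());
}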
@@ -58,10 +58,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto&& src = inputs[0];
     auto&& tshp = inputs[1];
-    TensorLayout out_layout = src.layout;
+    TensorShape out_shape;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, false};
+        out_shape.ndim = 0;
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     mgb_assert(
             tshp.layout.ndim == 1,
@@ -69,17 +69,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             tshp.layout.ndim);
     size_t target_ndim = tshp.layout.shape[0];
-    out_layout.ndim = target_ndim;
+    out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
     for (size_t i = 0; i < target_ndim; ++i) {
-        out_layout.shape[i] = ptr[i];
+        out_shape[i] = ptr[i];
     }
-    mgb_assert(valid_broadcast(src.layout, out_layout),
+    mgb_assert(valid_broadcast(src.layout, out_shape),
               "the input shape %s can not be broadcasted to target shape %s",
-              src.layout.TensorShape::to_string().c_str(),
-              out_layout.TensorShape::to_string().c_str());
-    return {{{out_layout, src.comp_node}}, true};
+              src.layout.to_string().c_str(),
+              out_shape.to_string().c_str());
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
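The assert in the hunk above delegates to `valid_broadcast`, which after this change receives the target as a `TensorShape`. As a rough sketch of the rule such a check enforces, assuming the usual numpy-style convention (right-align the shapes; every source dimension must equal the target dimension or be 1) — the helper below is illustrative, not MegBrain's actual implementation:

#include <cstddef>
#include <vector>

// Numpy-style broadcast validity: src may have fewer dims than dst;
// each right-aligned src dim must be 1 or match the dst dim.
// Illustrative only; not MegBrain's valid_broadcast.
bool valid_broadcast_sketch(const std::vector<size_t>& src,
                            const std::vector<size_t>& dst) {
    if (src.size() > dst.size())
        return false;
    size_t off = dst.size() - src.size();
    for (size_t i = 0; i < src.size(); ++i)
        if (src[i] != 1 && src[i] != dst[off + i])
            return false;
    return true;
}

// valid_broadcast_sketch({3, 1}, {2, 3, 4}) -> true
// valid_broadcast_sketch({3, 2}, {2, 3, 4}) -> false (2 vs 4)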
@@ -108,10 +108,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto&& src = inputs[0];
     auto&& tshp = inputs[1];
-    TensorLayout out_layout = src.layout;
+    TensorShape out_shape;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, false};
+        out_shape.ndim = 0;
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     mgb_assert(
             tshp.layout.ndim == 1,
@@ -119,31 +119,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             tshp.layout.ndim);
     size_t target_ndim = tshp.layout.shape[0];
-    out_layout.ndim = target_ndim;
+    out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
     for (size_t i = 0; i < target_ndim; ++i) {
-        out_layout.shape[i] = ptr[i];
+        out_shape[i] = ptr[i];
     }
     if (src.layout.ndim == 0) {
-        return {{{out_layout, src.comp_node}}, false};
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(out_layout.shape[op.axis] == -1);
-        out_layout.shape[op.axis] = 1;
-        mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
+        mgb_assert(out_shape[op.axis] == -1);
+        out_shape[op.axis] = 1;
+        mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
                   "can not reshape from %s to %s",
                   src.layout.to_string().c_str(),
-                  out_layout.to_string().c_str());
-        out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
+                  out_shape.to_string().c_str());
+        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
     } else {
-        mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
+        mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
                   "can not reshape from %s to %s",
                   src.layout.to_string().c_str(),
-                  out_layout.to_string().c_str());
+                  out_shape.to_string().c_str());
     }
-    return {{{out_layout, src.comp_node}}, true};
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 OP_TRAIT_REG(Reshape, Reshape)
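The Reshape hunk also covers the `-1` placeholder axis: the divisibility assert guarantees the unknown dimension can be recovered as `total_nr_elems / product_of_the_other_dims`, which is exactly what the last line of that branch computes. A standalone sketch of the same inference, with illustrative names and `std::optional` standing in for `mgb_assert`:

#include <cstddef>
#include <optional>
#include <vector>

// Infer the placeholder axis of a reshape target: set it to 1, check
// divisibility, then divide out the known dims. The value initially
// stored at `axis` is ignored, mirroring `out_shape[op.axis] = 1;`.
std::optional<std::vector<size_t>> infer_reshape_sketch(
        size_t src_elems, std::vector<size_t> tgt, size_t axis) {
    tgt[axis] = 1;
    size_t known = 1;
    for (size_t d : tgt)
        known *= d;
    if (known == 0 || src_elems % known != 0)
        return std::nullopt;  // "can not reshape from %s to %s"
    tgt[axis] = src_elems / known;
    return tgt;
}

// infer_reshape_sketch(24, {2, 0, 4}, 1) -> {2, 3, 4}
// infer_reshape_sketch(24, {5, 0, 4}, 1) -> nullopt (24 % 20 != 0)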