未验证 提交 6e1d0efd 编写于 作者: G GaoWei8 提交者: GitHub

fix concat shape error (#25414) (#25438)

* fix concat shape error
test=develop
上级 5c84eac8
...@@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape( ...@@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape(
auto out_dims = inputs_dims[0]; auto out_dims = inputs_dims[0];
size_t in_zero_dims_size = out_dims.size(); size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) { for (size_t i = 1; i < n; i++) {
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(),
platform::errors::InvalidArgument(
"The shape of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s].",
i, inputs_dims[0], i, inputs_dims[i]));
for (size_t j = 0; j < in_zero_dims_size; j++) { for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) { if (j == axis) {
if (is_runtime) { if (is_runtime) {
out_dims[axis] += inputs_dims[i][j]; out_dims[axis] += inputs_dims[i][j];
} else { } else {
if (inputs_dims[i][j] == -1) { if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
out_dims[axis] = -1; out_dims[axis] = -1;
} else { } else {
out_dims[axis] += inputs_dims[i][j]; out_dims[axis] += inputs_dims[i][j];
...@@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape( ...@@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape(
"[%s], input[%d]'s shape = [%s].", "[%s], input[%d]'s shape = [%s].",
j, i, inputs_dims[0], i, inputs_dims[i])); j, i, inputs_dims[0], i, inputs_dims[i]));
} }
if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
out_dims[j] = inputs_dims[i][j];
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册