From 6e1d0efd6d8b25dd1d32fd889b51f131b356598d Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Thu, 9 Jul 2020 09:48:03 +0800 Subject: [PATCH] fix concat shape error (#25414) (#25438) * fix concat shape error test=develop --- paddle/fluid/operators/concat_op.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index eceb68815e7..c9dcda1adb3 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape( auto out_dims = inputs_dims[0]; size_t in_zero_dims_size = out_dims.size(); for (size_t i = 1; i < n; i++) { + PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(), + platform::errors::InvalidArgument( + "The shape of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + i, inputs_dims[0], i, inputs_dims[i])); for (size_t j = 0; j < in_zero_dims_size; j++) { if (j == axis) { if (is_runtime) { out_dims[axis] += inputs_dims[i][j]; } else { - if (inputs_dims[i][j] == -1) { + if (inputs_dims[i][j] == -1 || out_dims[j] == -1) { out_dims[axis] = -1; } else { out_dims[axis] += inputs_dims[i][j]; @@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape( "[%s], input[%d]'s shape = [%s].", j, i, inputs_dims[0], i, inputs_dims[i])); } + if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) { + out_dims[j] = inputs_dims[i][j]; + } } } } -- GitLab