diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index eceb68815e76854e47e3fe90f275aa2d9f96faae..c9dcda1adb3f7bd481df3aa483b9bd3338e9e211 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape( auto out_dims = inputs_dims[0]; size_t in_zero_dims_size = out_dims.size(); for (size_t i = 1; i < n; i++) { + PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(), + platform::errors::InvalidArgument( + "The shape of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + i, inputs_dims[0], i, inputs_dims[i])); for (size_t j = 0; j < in_zero_dims_size; j++) { if (j == axis) { if (is_runtime) { out_dims[axis] += inputs_dims[i][j]; } else { - if (inputs_dims[i][j] == -1) { + if (inputs_dims[i][j] == -1 || out_dims[j] == -1) { out_dims[axis] = -1; } else { out_dims[axis] += inputs_dims[i][j]; @@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape( "[%s], input[%d]'s shape = [%s].", j, i, inputs_dims[0], i, inputs_dims[i])); } + if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) { + out_dims[j] = inputs_dims[i][j]; + } } } }