diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index eda54f76b898cdf893347d31cadb86dea892a4ce..37f69426b62fedf8cbeca68105fb86fb4ea72eab 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -56,6 +56,9 @@ class ReshapeOp : public framework::OperatorWithKernel { static framework::DDim ValidateShape(const std::vector shape, const framework::DDim &in_dims) { const int64_t in_size = framework::product(in_dims); + auto in_dims_vec = framework::vectorize(in_dims); + bool all_positive = std::all_of(in_dims_vec.cbegin(), in_dims_vec.cend(), + [](int64_t i) { return i > 0; }); // only one dimension can be set to -1, whose size will be automatically // infered. const int64_t unk_dim_val = -1; @@ -88,7 +91,7 @@ class ReshapeOp : public framework::OperatorWithKernel { } if (unk_dim_idx != -1) { - if (in_size > 0) { + if (all_positive) { // in_size < 0 and is un-determinate in compile time, skip the check, // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], // capacity = -24, in_size = -8, output_shape[0] = 0 diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0f4fe1b559e1e79bace82e13f0f8828b869d69b7..5b4f1efe479b12cb8ec390b8753d097764d70860 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4833,11 +4833,6 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): """ def __check_input(x, y): - if len(y.shape) > len(x.shape): - raise ValueError( - "Invalid inputs for matmul. " - "x's rank should be always greater than or equal to y'rank.") - x_shape = list(x.shape) y_shape = list(y.shape) if len(x_shape) == 1: @@ -4853,10 +4848,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): if x_shape[-1] != y_shape[-2]: raise ValueError("Invalid inputs for matmul.") - if len(y_shape) > 2: + if len(y_shape) > 2 and len(x_shape) > 2: for i, dim_x in enumerate(x_shape[:-2]): if dim_x != y_shape[i]: - raise ValueError("Invalid inputs for matmul.") + raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" % + (x.shape, y.shape)) __check_input(x, y)