diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 39bf2118d603881531bf583ae468e8dc9b8bd181..306dfa8069bce1009455d8720155331fe296610e 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -34,21 +34,27 @@ class ReshapeOp : public framework::OperatorWithKernel { auto shape = ctx->Attrs().Get>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); auto x_dims = ctx->GetInputDim("X"); - // TODO(qiao) change batch_size - for (size_t i = 1; i < shape.size(); ++i) { - PADDLE_ENFORCE(shape[i] > 0, - "Each dimension of Attr(shape) " - "must be positive except the first one."); - } - if (shape[0] < 0) { - shape[0] = x_dims[0]; + + std::vector neg_dims_idx; + for (size_t i = 0; i < shape.size(); ++i) { + PADDLE_ENFORCE(shape[i] > 0 || shape[i] == -1, + "Each dimension of Attr(shape) must be positive or -1."); + if (shape[i] == -1) { + neg_dims_idx.push_back(i); + PADDLE_ENFORCE(neg_dims_idx.size() <= 1, + "Only one dimension of Attr(shape) can be -1."); + } } + // capacity check int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); int64_t in_size = framework::product(x_dims); - PADDLE_ENFORCE_EQ(capacity, in_size, + if (neg_dims_idx.size() == 1) { + shape[neg_dims_idx[0]] = in_size / (-capacity); + PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0, "The size of Input(X) mismatches with Attr(shape)."); + } // resize output std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), @@ -88,6 +94,9 @@ the tensor X into a 1-D tensor: [1, 2, 3, 4] +One dimension in the target shape can be set -1, and the real dimension +will be infered from the original shape of Input(X) and other +dimensions in the target shape. )DOC"); } };