diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 7fd33bf662a1d0b7b6fa4e772bdadbf34b2f4fdd..d82d828747c0c822195b699359b8e62d1cf7e92d 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -34,21 +34,33 @@ class ReshapeOp : public framework::OperatorWithKernel { auto shape = ctx->Attrs().Get>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); auto x_dims = ctx->GetInputDim("X"); - // TODO(qiao) change batch_size - for (size_t i = 1; i < shape.size(); ++i) { - PADDLE_ENFORCE(shape[i] > 0, - "Each dimension of Attr(shape) " - "must be positive except the first one."); - } - if (shape[0] < 0) { - shape[0] = x_dims[0]; + + std::vector neg_dims_idx; + // set some dimension to -1 if it is unknown + const int unknown_size = -1; + for (size_t i = 0; i < shape.size(); ++i) { + PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size, + "Each dimension of Attr(shape) must be positive or %d.", + unknown_size); + if (shape[i] == unknown_size) { + neg_dims_idx.push_back(i); + PADDLE_ENFORCE(neg_dims_idx.size() <= 1, + "Only one dimension of Attr(shape) can be unknown."); + } } - // capacity check + int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); int64_t in_size = framework::product(x_dims); - PADDLE_ENFORCE_EQ(capacity, in_size, - "The size of Input(X) mismatches with Attr(shape)."); + if (neg_dims_idx.size() == 1) { + // dim infer + shape[neg_dims_idx[0]] = in_size / (-capacity); + // recalculate capacity + capacity = shape[neg_dims_idx[0]] * (-capacity); + } + // capacity check + PADDLE_ENFORCE(capacity == in_size, + "The size of Input(X) mismatches with Attr(shape)."); // resize output std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), @@ -88,6 +100,9 @@ the tensor X into a 2-D tensor: [[1, 2, 3, 4]] +One dimension in the target shape can be set -1, representing that its +size is unknown. In this case, the real dimension will be infered from +the original shape of Input(X) and other dimensions in the target shape. )DOC"); } }; diff --git a/python/paddle/v2/fluid/tests/test_reshape_op.py b/python/paddle/v2/fluid/tests/test_reshape_op.py index 16bb6bb2af67f7d32a2fafc1cb37412084ec0829..18ee3aece656276fec9671df9baf298b7fd3c9b1 100644 --- a/python/paddle/v2/fluid/tests/test_reshape_op.py +++ b/python/paddle/v2/fluid/tests/test_reshape_op.py @@ -17,5 +17,19 @@ class TestReshapeOp(OpTest): self.check_grad(["X"], "Out") +class TestReshapeOpDimInfer(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [4, -1, 5]} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + if __name__ == '__main__': unittest.main()