diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 306dfa8069bce1009455d8720155331fe296610e..164f3104eb69caf4f93542de39c3d089694ffe03 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -42,19 +42,23 @@ class ReshapeOp : public framework::OperatorWithKernel { if (shape[i] == -1) { neg_dims_idx.push_back(i); PADDLE_ENFORCE(neg_dims_idx.size() <= 1, - "Only one dimension of Attr(shape) can be -1."); + "Only one dimension of Attr(shape) can be unknown."); } } - // capacity check int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); int64_t in_size = framework::product(x_dims); if (neg_dims_idx.size() == 1) { - shape[neg_dims_idx[0]] = in_size / (-capacity); - PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0, - "The size of Input(X) mismatches with Attr(shape)."); + // dim infer + shape[neg_dims_idx[0]] = in_size / (-capacity); + // recalculate capacity + capacity = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); } + // capacity check + PADDLE_ENFORCE(capacity == in_size, + "The size of Input(X) mismatches with Attr(shape)."); // resize output std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), diff --git a/python/paddle/v2/fluid/tests/test_reshape_op.py b/python/paddle/v2/fluid/tests/test_reshape_op.py index 16bb6bb2af67f7d32a2fafc1cb37412084ec0829..18ee3aece656276fec9671df9baf298b7fd3c9b1 100644 --- a/python/paddle/v2/fluid/tests/test_reshape_op.py +++ b/python/paddle/v2/fluid/tests/test_reshape_op.py @@ -17,5 +17,19 @@ class TestReshapeOp(OpTest): self.check_grad(["X"], "Out") +class TestReshapeOpDimInfer(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [4, -1, 5]} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + if __name__ == '__main__': unittest.main()