From 0a75ed6f5bcad53b3622ee89d48a5613cc4ef396 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 14 Dec 2017 05:06:19 +0000 Subject: [PATCH] Add unit test for dimension inference in reshape_op --- paddle/operators/reshape_op.cc | 14 +++++++++----- python/paddle/v2/fluid/tests/test_reshape_op.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 306dfa806..164f3104e 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -42,19 +42,23 @@ class ReshapeOp : public framework::OperatorWithKernel { if (shape[i] == -1) { neg_dims_idx.push_back(i); PADDLE_ENFORCE(neg_dims_idx.size() <= 1, - "Only one dimension of Attr(shape) can be -1."); + "Only one dimension of Attr(shape) can be unknown."); } } - // capacity check int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); int64_t in_size = framework::product(x_dims); if (neg_dims_idx.size() == 1) { - shape[neg_dims_idx[0]] = in_size / (-capacity); - PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0, - "The size of Input(X) mismatches with Attr(shape)."); + // dim infer + shape[neg_dims_idx[0]] = in_size / (-capacity); + // recalculate capacity + capacity = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); } + // capacity check + PADDLE_ENFORCE(capacity == in_size, + "The size of Input(X) mismatches with Attr(shape)."); // resize output std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), diff --git a/python/paddle/v2/fluid/tests/test_reshape_op.py b/python/paddle/v2/fluid/tests/test_reshape_op.py index 16bb6bb2a..18ee3aece 100644 --- a/python/paddle/v2/fluid/tests/test_reshape_op.py +++ b/python/paddle/v2/fluid/tests/test_reshape_op.py @@ -17,5 +17,19 @@ class TestReshapeOp(OpTest): self.check_grad(["X"], "Out") +class TestReshapeOpDimInfer(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [4, -1, 5]} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + if __name__ == '__main__': unittest.main() -- GitLab