提交 0a75ed6f 编写于 作者: Y Yibing Liu

Add unit test for dimension inference in reshape_op

上级 685d1e3b
...@@ -42,19 +42,23 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -42,19 +42,23 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (shape[i] == -1) { if (shape[i] == -1) {
neg_dims_idx.push_back(i); neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1, PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be -1."); "Only one dimension of Attr(shape) can be unknown.");
} }
} }
// capacity check
int64_t capacity = int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims); int64_t in_size = framework::product(x_dims);
if (neg_dims_idx.size() == 1) { if (neg_dims_idx.size() == 1) {
shape[neg_dims_idx[0]] = in_size / (-capacity); // dim infer
PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0, shape[neg_dims_idx[0]] = in_size / (-capacity);
"The size of Input(X) mismatches with Attr(shape)."); // recalculate capacity
capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
} }
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output // resize output
std::vector<int64_t> shape_int64(shape.size(), 0); std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), std::transform(shape.begin(), shape.end(), shape_int64.begin(),
......
...@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest): ...@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册