提交 0a75ed6f 编写于 作者: Y Yibing Liu

Add unit test for dimension inference in reshape_op

上级 685d1e3b
......@@ -42,19 +42,23 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (shape[i] == -1) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be -1.");
"Only one dimension of Attr(shape) can be unknown.");
}
}
// capacity check
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims);
if (neg_dims_idx.size() == 1) {
shape[neg_dims_idx[0]] = in_size / (-capacity);
PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0,
"The size of Input(X) mismatches with Attr(shape).");
// dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
......
......@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册