未验证 提交 5b3c7ee7 编写于 作者: W Weilong Wu 提交者: GitHub

support gather test on prim and cinn (#51376)

* support gather test on prim and cinn

* reset timeout for gather
上级 2d9e103e
......@@ -59,7 +59,11 @@ void gather_grad(const Tensor& x,
std::vector<int> reverse_perm(tmp_perm);
// make origin ranks
for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
if (tmp_perm[i] >= 0) {
reverse_perm[tmp_perm[i]] = i;
} else {
reverse_perm[tmp_perm[i] + tmp_perm.size()] = i;
}
}
// transpose out_grad and zero grad to target rank.
......
......@@ -982,7 +982,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_gather_op PROPERTIES TIMEOUT 180)
set_tests_properties(test_gather_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250)
set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120)
......@@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS
test_transpose_op
test_reshape_op
test_mean_op
test_unsqueeze2_op)
test_unsqueeze2_op
test_gather_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -35,6 +35,7 @@ class TestGatherOp(OpTest):
self.op_type = "gather"
self.python_api = paddle.gather
self.config()
self.prim_op_type = "prim"
xnp = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {
'X': xnp,
......@@ -46,7 +47,7 @@ class TestGatherOp(OpTest):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
def config(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册