未验证 提交 5b3c7ee7 编写于 作者: W Weilong Wu 提交者: GitHub

support gather test on prim and cinn (#51376)

* support gather test on prim and cinn

* reset timeout for gather
上级 2d9e103e
...@@ -59,7 +59,11 @@ void gather_grad(const Tensor& x, ...@@ -59,7 +59,11 @@ void gather_grad(const Tensor& x,
std::vector<int> reverse_perm(tmp_perm); std::vector<int> reverse_perm(tmp_perm);
// make origin ranks // make origin ranks
for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) { for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
reverse_perm[tmp_perm[i]] = i; if (tmp_perm[i] >= 0) {
reverse_perm[tmp_perm[i]] = i;
} else {
reverse_perm[tmp_perm[i] + tmp_perm.size()] = i;
}
} }
// transpose out_grad and zero grad to target rank. // transpose out_grad and zero grad to target rank.
......
...@@ -982,7 +982,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120) ...@@ -982,7 +982,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_gather_op PROPERTIES TIMEOUT 180) set_tests_properties(test_gather_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250) set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250)
set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120) set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120) set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120)
...@@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS ...@@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS
test_transpose_op test_transpose_op
test_reshape_op test_reshape_op
test_mean_op test_mean_op
test_unsqueeze2_op) test_unsqueeze2_op
test_gather_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -35,6 +35,7 @@ class TestGatherOp(OpTest): ...@@ -35,6 +35,7 @@ class TestGatherOp(OpTest):
self.op_type = "gather" self.op_type = "gather"
self.python_api = paddle.gather self.python_api = paddle.gather
self.config() self.config()
self.prim_op_type = "prim"
xnp = np.random.random(self.x_shape).astype(self.x_type) xnp = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = { self.inputs = {
'X': xnp, 'X': xnp,
...@@ -46,7 +47,7 @@ class TestGatherOp(OpTest): ...@@ -46,7 +47,7 @@ class TestGatherOp(OpTest):
self.check_output(check_eager=True) self.check_output(check_eager=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True) self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
def config(self): def config(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册