未验证 提交 dca3a099 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Enhance gather vjp (#50786)

* tmp gather vjp

* support gather

* remove useless code

* fix compiling error

* fix ut

* add eager test

* add eager test

* add seed

* fix cpu error

* fix transpose op compat

* remove tensor index case

* fix prim_cinn

* fix ut

* add gather composite
上级 5c9299e5
...@@ -543,6 +543,7 @@ ...@@ -543,6 +543,7 @@
kernel : kernel :
data_type: x data_type: x
func : gather_grad func : gather_grad
composite : gather_grad(x, index, out_grad, axis, overwrite)
no_need_buffer : x no_need_buffer : x
- backward_op : group_norm_grad - backward_op : group_norm_grad
......
...@@ -113,9 +113,15 @@ class TestGatherGradComp(unittest.TestCase): ...@@ -113,9 +113,15 @@ class TestGatherGradComp(unittest.TestCase):
def test_cinn(self): def test_cinn(self):
paddle.disable_static() paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False) use_cinn = True
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(jiabin): CINN will crashed in this case open it when fixed # TODO(jiabin): CINN will crashed in this case open it when fixed
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False) use_cinn = False
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn)
for i in range(len(dy_res)): for i in range(len(dy_res)):
np.testing.assert_allclose( np.testing.assert_allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册