未验证 提交 625e30b7 编写于 作者: W wangxiaoning 提交者: GitHub

add gather_nd_comp_grad composite rule (#50966)

* comp gather_nd_grad

* fix

* test no cinn

* fix

* fix cinn
上级 792531b6
......@@ -786,7 +786,18 @@ void topk_grad(const Tensor& x,
if (x_grad) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
auto x_grad_tmp = put_along_axis<T>(zero_tensor, indices, out_grad, axis);
set_output<T>(x_grad_tmp, x_grad);
}
}
template <typename T>
void gather_nd_grad(const Tensor& x,
const Tensor& index,
const Tensor& out_grad,
Tensor* x_grad) {
if (x_grad) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
auto x_grad_tmp = scatter_nd_add<T>(zero_tensor, index, out_grad);
set_output<T>(x_grad_tmp, x_grad);
}
}
......
......@@ -565,6 +565,7 @@
func : GatherNdGradInferMeta
kernel :
func : gather_nd_grad
composite : gather_nd_grad(x, index, out_grad, x_grad)
no_need_buffer : x
- backward_op : gelu_grad
......
......@@ -1215,7 +1215,8 @@ set(TEST_CINN_OPS
test_elementwise_add_op
test_elementwise_sub_op
test_elementwise_div_op
test_elementwise_mul_op)
test_elementwise_mul_op
test_gather_nd_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -26,6 +26,7 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
xnp = np.random.random((5, 20)).astype("float64")
self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
......@@ -37,12 +38,13 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=False)
self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
class TestGatherNdOpWithIndex1(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
xnp = np.random.random((5, 20)).astype("float64")
self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")}
......@@ -60,7 +62,9 @@ class TestGatherNdOpWithLowIndex(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
self.enable_cinn = False
xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
index = np.array([[1], [2]]).astype("int64")
......@@ -74,7 +78,7 @@ class TestGatherNdOpWithLowIndex(OpTest):
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=False)
self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
class TestGatherNdOpIndex1(OpTest):
......@@ -82,6 +86,7 @@ class TestGatherNdOpIndex1(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
index = np.array([1, 2]).astype("int32")
......@@ -102,7 +107,9 @@ class TestGatherNdOpWithSameIndexAsX(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
self.enable_cinn = False
xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
index = np.array([[1, 1], [2, 1]]).astype("int64")
......@@ -113,7 +120,7 @@ class TestGatherNdOpWithSameIndexAsX(OpTest):
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=False)
self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
class TestGatherNdOpWithHighRankSame(OpTest):
......@@ -121,6 +128,7 @@ class TestGatherNdOpWithHighRankSame(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
shape = (5, 2, 3, 1, 10)
xnp = np.random.rand(*shape).astype("float64")
......@@ -133,7 +141,7 @@ class TestGatherNdOpWithHighRankSame(OpTest):
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=False)
self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
class TestGatherNdOpWithHighRankDiff(OpTest):
......@@ -141,6 +149,7 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
def setUp(self):
self.op_type = "gather_nd"
self.prim_op_type = "prim"
self.python_api = paddle.gather_nd
shape = (2, 3, 4, 1, 10)
xnp = np.random.rand(*shape).astype("float64")
......@@ -154,7 +163,7 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=False)
self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
# Test Python API
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册