Unverified commit 296b3ff0 authored by zqw_1997, committed by GitHub

add topk prim backward (#50679)

* tmp gather vjp

* support gather

* remove useless code

* fix compiling error

* fix ut

* add eager test

* add eager test

* add seed

* small change

* fix cpu error

* fix transpose op compat

* remove tensor index case

* fix prim_cinn

* small commit

* add cumsum prim backward

* small commit

* skip axis=None test case

* fix op generate error

* fix static test error

* remove unused code

* fix static test error

* small commit

* skip cpu float16 test case

* skip eager cpu cumsum float16 test case

* add eager and static UT

* fix ut

* add composite backward rule

* fix error

* fix type error and format error

* add try cpu+float16 test

* fix test bugs

* remove test for cpu+float16 and make y[0] be the grad arg

* add cinn test

* fix UT

* fix the wrong dim of v in test cases

* change y[0] to y[1] for grad in UT

* reshape flatten out

* Disable cinn single test

* use scatter_nd_add

* modify the reshape part of topk_grad

* delete useless build file

* to make the syntax right

* fix bug

* try use of put_along_axis

* remove cinn test

* reformat todo

* add silu composite rule

* fix code style.

* add cinn test

* fix composite grad maker code gen

* add prim in cumsum op test

* remove old test

* fix typo

* pass the static test

* fix typo

* modify optest and delete old test files

* remove normal test_top_k_op test

* fix typo

* pass axis=None test case

* buffer comment

* for debug

* add silu fp16 unit test.

* add static guard

* remove forward prim test

* remove same name axis

* modify test_top_k_v2_op.py to pass all local tests

* delete the useless testcase

* fix mistake

* add more test cases covering float16 and float32

---------
Co-authored-by: JiabinYang <360788950@qq.com>
Co-authored-by: GGBond8488 <857631483@qq.com>
Co-authored-by: zxcd <228587199@qq.com>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Parent e152e891
@@ -26,3 +26,4 @@
 - transpose
 - pad
 - cumsum
+- put_along_axis
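For context, put_along_axis (the op registered above) scatters values into a tensor at given indices along one axis, which is exactly the primitive the new topk backward rule relies on. A minimal illustration through the public Paddle API, with made-up shapes and values:

import paddle

arr = paddle.zeros([2, 3])
idx = paddle.to_tensor([[1], [2]], dtype='int64')  # one target column per row
out = paddle.put_along_axis(arr, idx, 9.0, axis=1)
# out: [[0., 9., 0.],
#       [0., 0., 9.]]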
@@ -744,5 +744,22 @@ void cumsum_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void topk_grad(const Tensor& x,
+               const Tensor& indices,
+               const Tensor& out_grad,
+               const Scalar& k,
+               const int& axis,
+               const bool& largest,
+               const bool& sorted,
+               Tensor* x_grad) {
+  if (x_grad) {
+    auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
+    auto x_grad_tmp = put_along_axis<T>(zero_tensor, indices, out_grad, axis);
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
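In short, the rule builds a zero tensor shaped like x and scatters out_grad into it at the top-k indices; k, largest, and sorted do not appear in the gradient because the selection is already encoded in indices. A NumPy sketch of the same computation (illustrative only, not the Paddle kernel):

import numpy as np

def topk_grad_ref(x, indices, out_grad, axis):
    # zero everywhere, then scatter out_grad back to the selected slots
    x_grad = np.zeros_like(x)
    np.put_along_axis(x_grad, indices, out_grad, axis)
    return x_grad

# topk(k=2, axis=1, largest=True) on [[1., 3., 2.]] picks indices [[1, 2]],
# so the upstream gradient [[10., 20.]] lands at those positions:
g = topk_grad_ref(np.array([[1., 3., 2.]]),
                  np.array([[1, 2]]), np.array([[10., 20.]]), axis=1)
# g == [[0., 10., 20.]]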
@@ -1524,6 +1524,7 @@
   kernel :
     func : topk_grad
     data_type : out_grad
+  composite : topk_grad(x, indices, out_grad, k, axis, largest, sorted, x_grad)
 
 - backward_op : trace_grad
   forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
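With the composite entry wired up, builds that enable prim backward decompose topk_grad into the primitives above instead of calling the fused kernel; results should match either way. A quick dygraph sanity check through the public API:

import paddle

x = paddle.rand([10, 20])
x.stop_gradient = False
values, indices = paddle.topk(x, k=5, axis=1, largest=True)
values.sum().backward()
print(x.grad.shape)  # [10, 20], nonzero only at the top-k positions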
@@ -1209,7 +1209,8 @@ set(TEST_CINN_OPS
     test_slice_op
     test_activation_op
     test_full_like_op
-    test_fill_any_like_op)
+    test_fill_any_like_op
+    test_top_k_v2_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
@@ -45,6 +45,7 @@ class TestTopkOp(OpTest):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 20)
@@ -60,7 +61,7 @@ class TestTopkOp(OpTest):
         self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(set(['X']), 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
 
 
 class TestTopkOp1(TestTopkOp):
@@ -77,7 +78,7 @@ class TestTopkOp2(TestTopkOp):
         self.largest = False
 
 
-class TestTopkOp3(OpTest):
+class TestTopkOp3(TestTopkOp):
     def init_args(self):
         self.k = 6
         self.axis = 1
@@ -85,6 +86,7 @@ class TestTopkOp3(OpTest):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(16, 100)
@@ -105,6 +107,7 @@ class TestTopkOp4(TestTopkOp):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 10, 5)
@@ -125,6 +128,7 @@ class TestTopkOp5(TestTopkOp):
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
         self.dtype = np.float64
         self.input_data = np.random.rand(10, 10, 5)
@@ -137,17 +141,39 @@ class TestTopkOp5(TestTopkOp):
         self.outputs = {'Out': output, 'Indices': indices}
 
 
-class TestTopkOp6(OpTest):
+class TestTopkOp6(TestTopkOp):
     def init_args(self):
-        self.k = 100
+        self.k = 3
         self.axis = 1
         self.largest = True
 
     def setUp(self):
         self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
         self.python_api = paddle.topk
-        self.dtype = np.float64
-        self.input_data = np.random.rand(80, 16384)
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
         self.init_args()
         self.inputs = {'X': self.input_data}
         self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest
+        )
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp7(TestTopkOp):
+    def init_args(self):
+        self.k = 10
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.topk
+        self.dtype = np.float16
+        self.input_data = np.random.rand(10, 20, 10)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
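These cases compare against a numpy_topk reference helper defined near the top of test_top_k_v2_op.py, which this diff does not show. Reconstructed here as a sketch, assuming the usual argsort-based implementation:

import numpy as np

def numpy_topk(x, k=1, axis=-1, largest=True):
    # argsort ascending; negate x first to get descending order for largest=True
    order = np.argsort(-x if largest else x, axis=axis)
    indices = order.take(indices=range(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices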