    add topk prim backward (#50679) · 296b3ff0
    zqw_1997 committed
    * tmp gather vjp
    
    * support gather
    
    * remove useless code
    
    * fix compiling error
    
    * fix ut
    
    * add eager test
    
    * add eager test
    
    * add seed
    
    * small change
    
    * fix cpu error
    
    * fix transpose op compat
    
    * remove tensor index case
    
    * fix prim_cinn
    
    * small commit
    
    * add cumsum prim backward
    
    * small commit
    
    * skip axis=None test case
    
    * fix op generation error
    
    * fix static test error
    
    * remove unused code
    
    * fix static test error
    
    * small commit
    
    * skip cpu float16 test case
    
    * skip eager cpu cumsum float16 test case
    
    * add eager and static UT
    
    * fix ut
    
    * add composite backward rule
    
    * fix error
    
    * fix type error and format error
    
    * add try cpu+float16 test
    
    * fix test bugs
    
    * remove test for cpu+float16 and make y[0] be the grad arg
    
    * add cinn test
    
    * fix UT
    
    * fix the wrong dim of v in test cases
    
    * change y[0] to y[1] for grad in UT
    
    * reshape flatten out
    
    * Disable cinn single test
    
    * use scatter_nd_add
    
    * modify the reshape part of topk_grad
    
    * delete useless build file
    
    * to make the syntax right
    
    * modify bug
    
    * try use of put_along_axis
    
    * remove cinn test
    
    * reformat todo
    
    * add silu composite rule
    
    * fix code style.
    
    * add cinn test
    
    * fix composite grad maker code gen
    
    * add prim in cumsum op test
    
    * remove old test
    
    * fix typo
    
    * pass the static test
    
    * fix typo
    
    * modify optest and delete old test files
    
    * remove normal test_top_k_op test
    
    * fix typo
    
    * pass axis=None test case
    
    * buffer comment
    
    * for debug
    
    * add silu fp16 unit test.
    
    * add static guard
    
    * remove forward prim test
    
    * remove same name axis
    
    * modify the test_top_v2_op.py to pass all local tests
    
    * delete the useless testcase
    
    * fix mistake
    
    * add more testcases to test dtype16 and dtype32
    
    ---------
    Co-authored-by: JiabinYang <360788950@qq.com>
    Co-authored-by: GGBond8488 <857631483@qq.com>
    Co-authored-by: zxcd <228587199@qq.com>
    Co-authored-by: Charles-hit <wanghao107@baidu.com>
api.yaml 273 bytes
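The commit message names three backward rules: topk (built with scatter_nd_add and then put_along_axis), cumsum, and silu. The sketch below illustrates the math behind them. It uses plain Python on nested lists. It is not Paddle's composite implementation, and every function name in it (`topk_last_axis`, `topk_grad_last_axis`, `cumsum_grad`, `silu_grad`) is hypothetical. The topk gradient scatters the upstream gradient into a zero tensor shaped like `x`, at the selected indices along the last axis. Every other position receives zero.

```python
import math


def topk_last_axis(x, k):
    """Hypothetical forward helper: top-k values and indices along the last axis, largest first."""
    values, indices = [], []
    for row in x:
        idx = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        indices.append(idx)
        values.append([row[i] for i in idx])
    return values, indices


def topk_grad_last_axis(x_shape, indices, out_grad):
    """Hypothetical topk backward: scatter out_grad into zeros of x's shape
    at the top-k indices (the effect of put_along_axis / scatter-add on zeros)."""
    rows, cols = x_shape
    x_grad = [[0.0] * cols for _ in range(rows)]
    for r in range(rows):
        for j, c in enumerate(indices[r]):
            x_grad[r][c] += out_grad[r][j]  # unselected positions stay 0
    return x_grad


def cumsum_grad(out_grad):
    """Hypothetical 1-D cumsum backward: a reversed cumulative sum of out_grad,
    since x[i] contributes to every output y[j] with j >= i."""
    x_grad, running = [0.0] * len(out_grad), 0.0
    for i in range(len(out_grad) - 1, -1, -1):
        running += out_grad[i]
        x_grad[i] = running
    return x_grad


def silu_grad(x, out_grad):
    """Hypothetical silu backward for a scalar: d/dx [x * sigmoid(x)]
    = sigmoid(x) * (1 + x * (1 - sigmoid(x)))."""
    s = 1.0 / (1.0 + math.exp(-x))
    return out_grad * s * (1.0 + x * (1.0 - s))
```

For example, `topk_last_axis([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]], 2)` selects indices `[[1, 2], [2, 0]]`. Feeding upstream gradients `[[10.0, 20.0], [30.0, 40.0]]` back through `topk_grad_last_axis` gives `[[0.0, 10.0, 20.0], [40.0, 0.0, 30.0]]`.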