    Lml/prim op pywrapper (#41813) · ebf4fe6e
    Committed by levi131
    * native commit for triple grad of sigmoid
    
    * Updated unittests files
    
    * init functional jacobian api
    
    * Updated trible_test func
    
    * Updated gradient_checker & test_script
    
    * finish test with dtype float32
    
    * add float64 test case
    
    * polish code
    
    * use atol=1e-5 with dtype float64
    
    * fix for ci
    
    * set timeout for test_jacobian
    
    * fix dygraph grad to support high differential
    
    * polish API docstring
    
    * Updated gradient checker and some related files
    
    * fix double grad strip error for high differential
    
    * fix double grad strip error for high differential
    
    * Add Sigmoid triple grad tests
    
    * fix dygraph double grad dtype error when calling for high differential scenario
    
    * Updated triple grad tests func
    
    * Use np.random to initialize ddx
    
    * Updated triple_grad_check func
    
    * add todo for gradient checker and refine some comments
    
    * remove additional code
    
    * add test for warning in backward.py
    
    * format python code
    
    * support multi input in triple gradient checker
    
    * Add matmul triple grad kernel
    
    * Updated comments of TODO
    
    * Supported some special tests
    
    * Change code-format to follow CI std
    
    * Updated gradient_checker.py
    
    * Fix conflicts
    
    * Removed unnecessary printing log
    
    * Change code style to follow CI std
    
    * merge upstream
    
    * add priops.py
    
    * add_p
    
    * rm useless files
    
    * add sub_p mul_p div_p
    
    * add sqrt_p and tanh_p
    
    * add reshape_p
    
    * add broadcast_p
    
    * Add python primitive wrappers.
    
    * Jvp rules updated.
    
    * JVP rules done for all the 17 primops.
    
    * quick check and fixes.
    
    * add jvp(op, *args)
    
    * add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p
    
    * add split_p and concat_p
    
    * add gather_p and scatter_add_p
    
    * add slice_select_p and slice_assign_p
    
    * Add transpose rules.
    
    * add multi input check for add_p, sub_p, mul_p, div_p
    
    * update concat_p
    
    * Linearize and transpose in progress.
    
    * refine gather_p and scatter_add_p
    
    * updated.
    
    * update transpose.
    
    * refine slice_assign_p and slice_select_p
    
    * init commit for lower
    
    * Merged with primitive ops.
    
    * small update
    
    * add rules for orig2prim and prim2orig
    
    * add 9 tests for prim ops
    
    * add more test and fix some bug
    
    * add more test
    
    * register proto
    
    * Adding primops test.
    
    * add shape valid check for broadcast_p op, and add keepdim attr into reduce_p op proto
    
    * support multi input and multi output for split_p and concat_p
    
    * Test updated.
    
    * update
    
    * fix slice bug for slice_select_p and slice_assign_p
    
    * updated.
    
    * Ops updated.
    
    * Refactor and bug fixes.
    
    * updated.
    
    * finish orig2prim and prim2orig rules
    
    * dtype for axis attr should be long int
    
    * update dtype for axis attr int64_t
    
    * update for iscan CI
    
    * Update primx.
    
    * Refactor vars in primx.
    
    * update for lower transform
    
    * update primx.py
    
    * update
    
    * Fix linearize and transpose.
    
    * Update is_dot
    
    * Update is_dot
    
    * Update is_dot
    
    * add gradient aggregation, fix add_transpose.
    
    * pass first linearize+transpose test.
    
    * update test
    
    * add_prim_op_pywrapper
    
    * Add primops UT
    
    * Fix set_value and update
    
    * Fix code format and PR-CI-Coverage
    Co-authored-by: veyron95 <veyron_wu@163.com>
    Co-authored-by: Jiabin Yang <360788950@qq.com>
    Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
    Co-authored-by: 0x45f <wangzhen45@baidu.com>
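The squashed commits above mention JVP rules, transpose rules, linearize/transpose and gradient aggregation for primitive ops such as `add_p` and `mul_p`. The following is a minimal, hypothetical sketch of those ideas over plain Python floats. Every function in it (`add_jvp`, `mul_transpose`, `f_jvp`, `f_vjp`, etc.) is invented for illustration; it is not the API added in `primops.py` or anywhere else in Paddle.

```python
# Hypothetical sketch: JVP and transpose rules for two scalar primitives.
# The names echo add_p / mul_p from the commit list, but this is NOT Paddle's API.

def add_p(x, y):
    return x + y

def mul_p(x, y):
    return x * y

# JVP rules: given primal inputs and their tangents, return the output tangent.
def add_jvp(x, y, dx, dy):
    return add_p(dx, dy)

def mul_jvp(x, y, dx, dy):
    # Product rule: d(x * y) = dx * y + x * dy
    return add_p(mul_p(dx, y), mul_p(x, dy))

# Transpose rules: map an output cotangent back onto the linear inputs.
def add_transpose(ct):
    # z = a + b is linear in both a and b, so each receives the full cotangent.
    return ct, ct

def mul_transpose(ct, const):
    # z = a * const is linear in a (const is a primal value, not a tangent).
    # Scalar multiplication commutes, so the same rule covers z = const * a.
    return mul_p(ct, const)

def f_jvp(x, dx):
    # Forward mode for f(x) = x * x + x: returns (f(x), f'(x) * dx).
    t1 = mul_p(x, x)
    dt1 = mul_jvp(x, x, dx, dx)
    y = add_p(t1, x)
    dy = add_jvp(t1, x, dt1, dx)
    return y, dy

def f_vjp(x, ct):
    # Reverse mode: transpose the linearized program
    #   a = dx * x;  b = x * dx;  dt1 = a + b;  dy = dt1 + dx
    # by walking it backwards from the output cotangent ct.
    ct_dt1, ct_dx = add_transpose(ct)   # from dy = dt1 + dx
    ct_a, ct_b = add_transpose(ct_dt1)  # from dt1 = a + b
    ct_dx += mul_transpose(ct_a, x)     # from a = dx * x  (aggregate into dx)
    ct_dx += mul_transpose(ct_b, x)     # from b = x * dx  (aggregate into dx)
    return ct_dx
```

For example, `f_jvp(3.0, 1.0)` returns `(12.0, 7.0)` and `f_vjp(3.0, 1.0)` returns `7.0`, so forward (JVP) and reverse (transposed) modes agree on f'(3) = 2·3 + 1. The `+=` on `ct_dx` is the gradient aggregation step: `dx` feeds three linear ops, so its cotangents are summed.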
primops.py 8.0 KB