    [PRIM][IR]support add vjp (#56163) · 84482da8
    Charles-hit committed
    * [prim][newir] add basic framework for primitive
    
    * support desctensor in new ir
    
    * add vjp interface
    
    * support vjp in new ir
    
    * support vjp in new ir
    
    * polish vjp interface
    
    * fix stop_gradients set
    
    * fix vjp dispatch
    
    * add comment
    
    * add vjp test for new ir
    
    * add test for tanh vjp
    
    * add eager and static backend to wrap lower-level API
    
    * support call_vjp pybind
    
    * polish code and add test for vjp
    
    * remove useless code
    
    * polish code
    
    * remove useless code
    
    * support mean vjp
    
    * add test for mean vjp and support has_vjp function
    
    * fix call_vjp
    
    * polish code
    
    * add primitive ops set for backend
    
    * add vjp test for tanh_
    
    * fix inference CI
    
    * fix inference ci
    
    * modify fluid cmake
    
    * remove useless deps
    
    * add cmake
    
    * fix comment
    
    * fix test
    
    * polish code
    
    * modify backward stop_gradients
    
    * modify static_backend.cc
    
    * support add and add_inplace vjp
    
    * remove useless code
    
    * remove useless code
    
    * remove cout
    
    * remove cout
    
    * fix add_grad
    
    * fix add test exe
    
    ---------
    Co-authored-by: cxxly <chenxx_id@163.com>
    Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
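    A VJP (vector-Jacobian product) rule maps the gradient flowing into an op's output back to gradients for its inputs. This commit adds such rules for tanh, mean, and add. The sketch below shows the textbook math for those three ops in plain Python on flat lists.

    It is not Paddle's implementation. The function names, the signatures, and the assumption that both add inputs have the same shape are hypothetical and chosen only for illustration.

```python
import math

# Hypothetical, illustration-only VJP rules on flat Python lists.
# These are NOT Paddle APIs; names and signatures are assumptions.

def tanh_vjp(out, out_grad):
    # out = tanh(x); d tanh(x)/dx = 1 - tanh(x)^2, computed from the forward output.
    return [g * (1.0 - o * o) for o, g in zip(out, out_grad)]

def mean_vjp(x, out_grad):
    # out = mean of all N elements of x (a scalar); each input gets out_grad / N.
    n = len(x)
    return [out_grad / n] * n

def add_vjp(x, y, out_grad):
    # out = x + y with x and y of the same shape (assumption: no broadcasting);
    # the Jacobian w.r.t. each input is the identity, so both get out_grad unchanged.
    return list(out_grad), list(out_grad)

if __name__ == "__main__":
    xs = [0.5, -1.0]
    outs = [math.tanh(v) for v in xs]
    print(tanh_vjp(outs, [1.0, 1.0]))
    print(mean_vjp([1.0, 2.0, 3.0, 4.0], 2.0))
    print(add_vjp([1.0, 2.0], [3.0, 4.0], [0.1, 0.2]))
```

    With broadcasting, as general-purpose add ops usually support, the gradient for a broadcast input would also have to be summed over the broadcast dimensions. The sketch omits that step.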
vjp.cc