# Backward (gradient) API definitions. Each entry names the forward API it
# differentiates, declares the gradient API's arguments and outputs, and then
# either lists the infer_meta/kernel functions that implement it or invokes an
# existing API to compute the gradient.
- backward_api : matmul_grad
  forward : matmul (const Tensor& x, const Tensor& y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out)
  # Forward inputs plus the upstream gradient `out_grad`, with transpose flags
  # mirroring the forward signature.
  args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false)
  output : Tensor(x_grad), Tensor(y_grad)
  infer_meta :
    func : MatmulGradInferMeta
  kernel :
    func : matmul_grad

- backward_api : scale_grad
  forward : scale (const Tensor& x, const Scalar& scale, float bias, bool bias_after_scale) -> Tensor(out)
  args : (const Tensor& out_grad, const Scalar& scale, float bias=0.0, bool bias_after_scale=true)
  output : Tensor(x_grad)
  # The forward op computes out = scale * x + bias (or scale * (x + bias) when
  # bias_after_scale is false). In both cases d(out)/dx == scale, so the
  # gradient is out_grad * scale with no bias term. Pass 0.0 rather than the
  # forward `bias`, which would otherwise be added into x_grad.
  invoke : scale(out_grad, scale, 0.0, bias_after_scale)

# TODO(zhangyunfei) Configs for double-grad and triple-grad APIs will be supported in the future.
#
# - backward_api : matmul_double_grad
#   forward : matmul_grad (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x, bool transpose_y) -> tuple<Tensor, Tensor>(dx, dy)
#   args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, const Tensor& dx_grad, const Tensor& dy_grad, bool transpose_x, bool transpose_y)
#   output : Tensor(d2x), Tensor(d2y), Tensor(dout_grad)
#   infer_meta :
#     func : MatmulDoubleGradInferMeta
#   kernel :
#     func : matmul_double_grad

# - backward_api : matmul_triple_grad
#   forward : matmul_double_grad (const Tensor& x, const Tensor& y, const Tensor& out_grad, const Tensor& dx_grad, const Tensor& dy_grad, bool transpose_x, bool transpose_y) -> tuple<Tensor, Tensor, Tensor>(d2x, d2y, dout_grad)
#   args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, const Tensor& dx_grad, const Tensor& dy_grad, const Tensor& d2x_grad, const Tensor& d2y_grad, const Tensor& dout_grad_grad, bool transpose_x, bool transpose_y)
#   output : Tensor(d3x), Tensor(d3y), Tensor(d2out_grad), Tensor(ddx_grad), Tensor(ddy_grad)
#   infer_meta :
#     func : MatmulTripleGradInferMeta
#   kernel :
#     func : matmul_triple_grad