Unverified · Commit 870c0837 · Authored by xiaoguoguo626807 · Committed by GitHub

【prim】 modify_yaml (#51436)

* modify_yaml

* delete default param

* add output for matmul_double_grad
Parent 4a8b97ee
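Every hunk below applies the same mechanical change: the composite entry of a backward op now lists the op's gradient outputs explicitly at the end of its argument list, so the YAML signature matches the full parameter list of the composite (primitive-decomposition) rule instead of leaving the outputs implicit. The matmul_double_grad hunk additionally forwards the real transpose_x / transpose_y attributes instead of hard-coded =false defaults (the "delete default param" item) and declares grad_out_grad as a third output. A minimal before/after sketch of the pattern, reconstructed from the first hunk; the args and output lines are assumptions inferred from the composite signature and are not shown in this excerpt:

- backward_op : add_grad
  args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)   # assumed, not in the diff
  output : Tensor(x_grad), Tensor(y_grad)                       # assumed, not in the diff
  kernel :
    func : add_grad
  no_need_buffer : x, y
  # before: composite : add_grad(x, y, out_grad, axis)
  # after:  the declared outputs are appended as explicit arguments
  composite : add_grad(x, y, out_grad, axis, x_grad, y_grad)
  backward : add_double_grad
  inplace : (out_grad -> x_grad)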
@@ -43,7 +43,7 @@
   kernel :
     func : add_grad
   no_need_buffer : x, y
-  composite : add_grad(x, y, out_grad, axis)
+  composite : add_grad(x, y, out_grad, axis, x_grad, y_grad)
   backward : add_double_grad
   inplace : (out_grad -> x_grad)
@@ -157,7 +157,7 @@
   args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   invoke : cast (out_grad, x.dtype())
-  composite: cast_grad(x, out_grad)
+  composite: cast_grad(x, out_grad, x_grad)
   no_need_buffer : x

 - backward_op : channel_shuffle_grad
@@ -358,7 +358,7 @@
     param : [x, y]
   kernel :
     func : divide_grad
-  composite : divide_grad(x, y, out, out_grad, axis)
+  composite : divide_grad(x, y, out, out_grad, axis, x_grad, y_grad)
   backward : divide_double_grad

 - backward_op : dropout_grad
@@ -400,7 +400,7 @@
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param: [x, y]
-  composite : elementwise_pow_grad(x, y, out_grad, axis)
+  composite : elementwise_pow_grad(x, y, out_grad, axis, x_grad, y_grad)
   kernel :
     func : elementwise_pow_grad
@@ -439,7 +439,7 @@
     func : expand_grad
   no_need_buffer : x
   backward : expand_double_grad
-  composite: expand_grad(x, out_grad, shape, x_grad_p)
+  composite: expand_grad(x, out_grad, shape, x_grad)

 - backward_op : exponential__grad
   forward : exponential_ (Tensor x, float lam) -> Tensor(out)
@@ -514,7 +514,7 @@
   kernel :
     data_type: x
     func : gather_grad
-  composite : gather_grad(x, index, out_grad, axis, overwrite)
+  composite : gather_grad(x, index, out_grad, axis, overwrite, x_grad)
   no_need_buffer : x

 - backward_op : group_norm_grad
@@ -689,7 +689,7 @@
     param : [x, y, grad_out]
   kernel :
     func : matmul_double_grad
-  composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x=false, transpose_y=false)
+  composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x, transpose_y, x_grad, y_grad, grad_out_grad)
   backward : matmul_triple_grad
   optional : grad_x_grad, grad_y_grad
@@ -1131,7 +1131,7 @@
     param : [input]
   kernel :
     func : slice_grad
-  composite: slice_grad(input, out_grad, axes, starts, ends, infer_flags, decrease_axis)
+  composite: slice_grad(input, out_grad, axes, starts, ends, infer_flags, decrease_axis, input_grad)
   backward : slice_double_grad
   no_need_buffer : input
@@ -1213,7 +1213,7 @@
   kernel :
     func : subtract_grad
   no_need_buffer : x, y
-  composite : subtract_grad(x, y, out_grad, axis)
+  composite : subtract_grad(x, y, out_grad, axis, x_grad, y_grad)
   backward : subtract_double_grad
   inplace : (out_grad -> x_grad)
@@ -1303,7 +1303,7 @@
   kernel :
     func : transpose_grad
   backward : transpose_double_grad
-  composite: transpose_grad(out_grad, perm)
+  composite: transpose_grad(out_grad, perm, x_grad)

 - backward_op : triangular_solve_grad
   forward : triangular_solve (Tensor x, Tensor y, bool upper, bool tranpose, bool unitriangular) -> Tensor(out)