Unverified commit 25726b94, authored by xiaoguoguo626807, committed by GitHub

[prim] modify multiply_double_grad composite rules to solve paddlescience poiseuille_flow.py problem (#54908)

* modify eular_beam

* modify matmul infermeta

* add test

* modify timeout

* modify multiply_double nullptr

* modify tanh_triple_grad node creation
Parent 19345fa7
@@ -412,7 +412,8 @@ void multiply_double_grad(const Tensor& x,
       }
     } else {
-      x_grad = nullptr;
+      auto dx = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
+      set_output<T>(dx, x_grad);
     }
   }
   if (y_grad) {
@@ -433,22 +434,22 @@ void multiply_double_grad(const Tensor& x,
         set_output<T>(dy, y_grad);
       }
     } else {
-      y_grad = nullptr;
+      auto dy = full<T>(phi::vectorize(y.dims()), 0.0, y.dtype());
+      set_output<T>(dy, y_grad);
     }
   }
   if (grad_out_grad) {
+    Tensor ddout;
     if (grad_x_grad && grad_y_grad) {
-      auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
-      set_output<T>(ddout, grad_out_grad);
+      ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
     } else if (grad_x_grad) {
-      auto ddout = grad_x_grad.get() * y;
-      set_output<T>(ddout, grad_out_grad);
+      ddout = grad_x_grad.get() * y;
     } else if (grad_y_grad) {
-      auto ddout = grad_y_grad.get() * x;
-      set_output<T>(ddout, grad_out_grad);
+      ddout = grad_y_grad.get() * x;
     } else {
-      grad_out_grad = nullptr;
+      ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, grad_out.dtype());
    }
+    set_output<T>(ddout, grad_out_grad);
   }
 }
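The net effect of the two hunks above: when a higher-order gradient input is missing, multiply_double_grad now writes zero-filled tensors into x_grad, y_grad, and grad_out_grad instead of leaving those outputs as nullptr, which downstream composite nodes (such as the second-order graph built for paddlescience's poiseuille_flow.py) could not consume. Below is a minimal NumPy sketch of the revised semantics, not taken from the commit: it assumes same-shape dense inputs, omits the broadcasting/reduction handling the real rule performs, and uses an illustrative function name rather than any Paddle API.

# Illustrative NumPy restatement of the revised multiply_double_grad rule
# (same-shape dense inputs assumed; names are hypothetical, not Paddle APIs).
import numpy as np

def multiply_double_grad_sketch(x, y, grad_out, ddx=None, ddy=None):
    # x_grad / y_grad: the branches not shown in the hunk compute the usual
    # product-rule terms; the shown change is the fallback, which used to be a
    # null output and is now a zero tensor.
    x_grad = ddy * grad_out if ddy is not None else np.zeros_like(x)
    y_grad = ddx * grad_out if ddx is not None else np.zeros_like(y)

    # grad_out_grad (ddout): ddx * y + ddy * x, with missing inputs treated as
    # zeros; the rule now always emits a tensor instead of a null output.
    if ddx is not None and ddy is not None:
        ddout = ddx * y + ddy * x
    elif ddx is not None:
        ddout = ddx * y
    elif ddy is not None:
        ddout = ddy * x
    else:
        ddout = np.zeros_like(grad_out)
    return x_grad, y_grad, ddout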
@@ -461,10 +462,10 @@ void add_double_grad(const Tensor& y,
                      Tensor* grad_out_grad) {
   if (grad_out_grad) {
     // ddout = ddx + ddy
+    Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
     if (!grad_x_grad && !grad_y_grad) {
-      grad_out_grad = nullptr;
+      set_output<T>(ddout, grad_out_grad);
     } else {
-      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
       if (grad_x_grad) {
         ddout = ddout + grad_x_grad.get();
       }
...
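add_double_grad gets the same treatment: ddout now starts from a zero tensor shaped like grad_out, so the output is a real tensor even when neither grad_x_grad nor grad_y_grad is provided. A hedged NumPy restatement, with illustrative names only:

# Illustrative NumPy restatement of the revised add_double_grad rule
# (names are hypothetical; the real rule operates on Paddle tensors).
import numpy as np

def add_double_grad_sketch(grad_out, ddx=None, ddy=None):
    # Start from zeros shaped like grad_out, then accumulate whichever
    # higher-order gradients are actually present; the result is never null.
    ddout = np.zeros_like(grad_out)
    if ddx is not None:
        ddout = ddout + ddx
    if ddy is not None:
        ddout = ddout + ddy
    return ddout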
@@ -2185,7 +2185,6 @@
     func : tanh_double_grad
   composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad)
   inplace : (grad_x_grad -> grad_out_grad)
-  backward : tanh_triple_grad

 - backward_op : tanh_grad
   forward : tanh (Tensor x) -> Tensor(out)
@@ -2211,18 +2210,6 @@
     func : tanh_shrink_grad
   inplace : (out_grad -> x_grad)

-- backward_op : tanh_triple_grad
-  forward : tanh_double_grad (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_out_new), Tensor(grad_out_grad)
-  args : (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_out_new_grad, Tensor grad_out_grad_grad)
-  output : Tensor(out_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad)
-  infer_meta :
-    func : GeneralTernaryGradInferMeta
-    param : [out, out, grad_x_grad_forward]
-  kernel :
-    func : tanh_triple_grad
-  inplace : (grad_x_grad_forward -> grad_out_forward_grad)
-  optional : grad_out_new_grad, grad_out_grad_grad
-
 - backward_op : temporal_shift_grad
   forward : temporal_shift(Tensor x, int seg_num, float shift_ratio = 0.25f, str data_format = "NCHW") -> Tensor(out)
   args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format)
...
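With the backward : tanh_triple_grad chaining and the dedicated tanh_triple_grad op removed, a third-order derivative of tanh is presumably obtained by differentiating the composite tanh_double_grad rule instead of a dedicated triple-grad node (this matches the "modify tanh_triple_grad node creation" item in the commit message). A hedged usage sketch, not taken from the commit, of requesting that third derivative in dygraph mode:

# Hedged sketch: third-order derivative of tanh via repeated paddle.grad calls.
import paddle

x = paddle.to_tensor([0.3, -0.7], stop_gradient=False)
y = paddle.tanh(x)
(dy,) = paddle.grad(y, x, create_graph=True)     # 1st derivative: 1 - tanh(x)^2
(d2y,) = paddle.grad(dy, x, create_graph=True)   # 2nd derivative
(d3y,) = paddle.grad(d2y, x)                     # 3rd derivative
print(d3y)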
@@ -2573,7 +2573,7 @@
   attrs : [bool use_mkldnn = false, bool use_cudnn = false]

 - op : tanh
-  backward : tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad
+  backward : tanh_grad, tanh_double_grad (tanh_grad_grad)
   inputs :
     x : X
   outputs :
...
@@ -2489,9 +2489,11 @@ def calc_gradient_helper(
         block, targets, inputs, block_no_grad_set, op_path_dict
     )

-    # only for composite to add grad_op input,
-    # tmp_targets includes targets and other outputs
-    # of the same forward op who create targets
+    # only for composite to add grad_var of the last forward op
+    # who has more than one output, but targets only has one,
+    # so targets_gradients only add one grad_var,
+    # eg: op1 -> op2 -> var1 / var2 targets = var1,
+    # targets_gradients = var1_grad, need to add var2_grad here.
     tmp_targets = targets
     if core._is_bwd_prim_enabled():
...
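The rewritten comment describes the case where the last forward op has several outputs but only one of them is a target, so targets_gradients starts with a single grad var and the helper must also create the sibling output's grad var before the composite backward rule runs. A hedged static-graph sketch of that graph shape, not taken from the commit (split is chosen only as a convenient one-op, two-output example):

# Hedged sketch: one forward op with two outputs, only one of which is a target.
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", shape=[4, 6], dtype="float32")
    x.stop_gradient = False
    var1, var2 = paddle.split(x, num_or_sections=2, axis=1)  # one op, two outputs
    target = paddle.sum(var1)                                # only var1 is a target
    grads = paddle.static.gradients([target], [x])           # with prim enabled, var2's
                                                             # grad var must also exist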