未验证 提交 25726b94 编写于 作者: X xiaoguoguo626807 提交者: GitHub

[prim] modify multiply_double_grad composite rules to solve paddlescience...

[prim] modify multiply_double_grad composite rules to solve paddlescience poiseuille_flow.py problem (#54908)

* modify eular_beam

* modify matmul infermeta

* add test

* modify timeout

* modify mutiply_double nullptr

* modify tanh_triple_gradnode create
上级 19345fa7
......@@ -412,7 +412,8 @@ void multiply_double_grad(const Tensor& x,
}
} else {
x_grad = nullptr;
auto dx = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
set_output<T>(dx, x_grad);
}
}
if (y_grad) {
......@@ -433,22 +434,22 @@ void multiply_double_grad(const Tensor& x,
set_output<T>(dy, y_grad);
}
} else {
y_grad = nullptr;
auto dy = full<T>(phi::vectorize(y.dims()), 0.0, y.dtype());
set_output<T>(dy, y_grad);
}
}
if (grad_out_grad) {
Tensor ddout;
if (grad_x_grad && grad_y_grad) {
auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
set_output<T>(ddout, grad_out_grad);
ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
} else if (grad_x_grad) {
auto ddout = grad_x_grad.get() * y;
set_output<T>(ddout, grad_out_grad);
ddout = grad_x_grad.get() * y;
} else if (grad_y_grad) {
auto ddout = grad_y_grad.get() * x;
set_output<T>(ddout, grad_out_grad);
ddout = grad_y_grad.get() * x;
} else {
grad_out_grad = nullptr;
ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, grad_out.dtype());
}
set_output<T>(ddout, grad_out_grad);
}
}
......@@ -461,10 +462,10 @@ void add_double_grad(const Tensor& y,
Tensor* grad_out_grad) {
if (grad_out_grad) {
// ddout = ddx + ddy
Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
if (!grad_x_grad && !grad_y_grad) {
grad_out_grad = nullptr;
set_output<T>(ddout, grad_out_grad);
} else {
Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
if (grad_x_grad) {
ddout = ddout + grad_x_grad.get();
}
......
......@@ -2185,7 +2185,6 @@
func : tanh_double_grad
composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad)
inplace : (grad_x_grad -> grad_out_grad)
backward : tanh_triple_grad
- backward_op : tanh_grad
forward : tanh (Tensor x) -> Tensor(out)
......@@ -2211,18 +2210,6 @@
func : tanh_shrink_grad
inplace : (out_grad -> x_grad)
- backward_op : tanh_triple_grad
forward : tanh_double_grad (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_out_new), Tensor(grad_out_grad)
args : (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_out_new_grad, Tensor grad_out_grad_grad)
output : Tensor(out_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [out, out, grad_x_grad_forward]
kernel :
func : tanh_triple_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
optional : grad_out_new_grad, grad_out_grad_grad
- backward_op : temporal_shift_grad
forward : temporal_shift(Tensor x, int seg_num, float shift_ratio = 0.25f, str data_format = "NCHW") -> Tensor(out)
args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format)
......
......@@ -2573,7 +2573,7 @@
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : tanh
backward : tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad
backward : tanh_grad, tanh_double_grad (tanh_grad_grad)
inputs :
x : X
outputs :
......
......@@ -2489,9 +2489,11 @@ def calc_gradient_helper(
block, targets, inputs, block_no_grad_set, op_path_dict
)
# only for composite to add grad_op input,
# tmp_targets includes targets and other outputs
# of the same forward op who create targets
# only for composite to add grad_var of the last forward op
# who has more than one output, but targets only has one,
# so targets_gradients only add one grad_var,
# eg: op1 -> op2 -> var1 / var2 targets = var1,
# targets_gradients = var1_grad, need to add var2_grad here.
tmp_targets = targets
if core._is_bwd_prim_enabled():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册