未验证 提交 b75d8c7e 编写于 作者: X xiaoguoguo626807 提交者: GitHub

Revert elementwise add (#53745)

* modify concat_grad add sum comp rule

* delete default mul_double_grad

* delete high grad test

* recover yaml

* modify yaml

* recover add_double_grad prim
上级 ce256f75
...@@ -67,7 +67,6 @@ black_ops_list = [ ...@@ -67,7 +67,6 @@ black_ops_list = [
prim_white_list = [ prim_white_list = [
"matmul_double_grad", "matmul_double_grad",
"tanh_double_grad", "tanh_double_grad",
"add_double_grad",
"subtract_double_grad", "subtract_double_grad",
] ]
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
kernel : kernel :
func : add_double_grad func : add_double_grad
optional : grad_x_grad, grad_y_grad optional : grad_x_grad, grad_y_grad
backward : add_triple_grad
inplace : (grad_x_grad -> grad_out_grad) inplace : (grad_x_grad -> grad_out_grad)
composite : add_double_grad(y, grad_out, grad_x_grad, grad_y_grad, axis, grad_out_grad) composite : add_double_grad(y, grad_out, grad_x_grad, grad_y_grad, axis, grad_out_grad)
...@@ -47,6 +48,17 @@ ...@@ -47,6 +48,17 @@
backward : add_double_grad backward : add_double_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : add_triple_grad
forward : add_double_grad (Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, int axis = -1) -> Tensor(grad_grad_out)
args : (Tensor grad_grad_x, Tensor grad_grad_y, Tensor grad_grad_out_grad, int axis = -1)
output : Tensor(grad_grad_x_grad), Tensor(grad_grad_y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [grad_grad_x, grad_grad_y]
kernel :
func : add_triple_grad
inplace : (grad_grad_out_grad -> grad_grad_x_grad)
- backward_op : amax_grad - backward_op : amax_grad
forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out) forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis={}, bool keepdim=false, bool reduce_all=false) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis={}, bool keepdim=false, bool reduce_all=false)
......
...@@ -25,7 +25,7 @@ import paddle ...@@ -25,7 +25,7 @@ import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid import core from paddle.fluid import core
'''
@param.parameterized_class( @param.parameterized_class(
('shape1', 'shape2'), ('shape1', 'shape2'),
[ [
...@@ -120,6 +120,7 @@ class TestAddHighGradCheck(unittest.TestCase): ...@@ -120,6 +120,7 @@ class TestAddHighGradCheck(unittest.TestCase):
for p in places: for p in places:
self.func_double(p) self.func_double(p)
self.func_triple(p) self.func_triple(p)
'''
@param.parameterized_class( @param.parameterized_class(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册