未验证 提交 b4024aaf 编写于 作者: X xiaoguoguo626807 提交者: GitHub

Revert elementwise (#53663)

* modify concat_grad add sum comp rule

* delete default mul_double_grad

* delete high grad test

* recover yaml

* modify yaml
上级 314d0418
......@@ -68,7 +68,6 @@ prim_white_list = [
"matmul_double_grad",
"tanh_double_grad",
"add_double_grad",
"multiply_double_grad",
"subtract_double_grad",
]
......
......@@ -617,6 +617,7 @@
func : multiply_double_grad
optional : grad_x_grad, grad_y_grad
inplace : (grad_x_grad -> grad_out_grad)
backward : multiply_triple_grad
composite : multiply_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, axis, x_grad, y_grad, grad_out_grad)
- backward_op : multiply_grad
......@@ -631,6 +632,17 @@
composite: multiply_grad(x, y, out_grad, axis, x_grad, y_grad)
backward : multiply_double_grad
- backward_op : multiply_triple_grad
forward : multiply_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, int aixs = -1) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, int axis = -1)
output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : multiply_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad
- backward_op : norm_grad
forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm)
args : (Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test)
......
......@@ -226,6 +226,7 @@ class TestSubtractHighGradCheck(unittest.TestCase):
self.func_triple(p)
'''
@param.parameterized_class(
('shape1', 'shape2'),
[
......@@ -328,7 +329,6 @@ class TestMultiplyHighGradCheck(unittest.TestCase):
for p in places:
self.func_double(p)
self.func_triple(p)
'''
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册