From 94c38803cec36a54b8fa32900674b6fc919d28e1 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Mon, 15 May 2023 16:59:39 +0800 Subject: [PATCH] Silu double grad (#53605) * add rules * modify no kernel yaml parse * success op generate * success test_silu_double * modify bug * modify static error * modify silu_grad input * modify kernel signature * modify kernel signature * code style * code style * review * delete opinfo modify --- .../generator/eager_gen.py | 1 + .../fluid/operators/generator/generate_op.py | 15 + .../operators/generator/generate_sparse_op.py | 2 + .../fluid/operators/generator/parse_utils.py | 17 +- .../operators/generator/templates/op.c.j2 | 2 + .../generator/templates/operator_utils.c.j2 | 3 +- .../fluid/operators/generator/tests_utils.py | 4 + .../composite_backward_api.h | 330 +-------------- .../composite_double_backward_api.h | 387 ++++++++++++++++++ paddle/phi/api/yaml/backward.yaml | 7 + paddle/phi/api/yaml/generator/api_base.py | 31 +- .../api/yaml/generator/backward_api_gen.py | 2 +- paddle/phi/api/yaml/op_compat.yaml | 2 +- test/prim/prim/vjp/test_comp_high_grad.py | 83 ++++ 14 files changed, 529 insertions(+), 357 deletions(-) create mode 100644 paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 8d28fa2438f..b556062343d 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -68,6 +68,7 @@ prim_white_list = [ "matmul_double_grad", "tanh_double_grad", "subtract_double_grad", + "silu_double_grad", ] # dict of special api that forward api's output will affect bacward api's output diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index 6f1ea3d8b3c..dad5df7430d 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -41,6 +41,7 @@ from tests_utils import ( is_base_op, is_composite_op, is_initializer_list, + is_only_composite_op, is_scalar, is_vec, supports_inplace, @@ -72,6 +73,7 @@ env.filters["assert_dense_or_sr"] = assert_dense_or_sr env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name env.tests["base_op"] = is_base_op env.tests["composite_op"] = is_composite_op +env.tests["only_composite_op"] = is_only_composite_op env.tests["vec"] = is_vec env.tests["scalar"] = is_scalar env.tests["initializer_list"] = is_initializer_list @@ -165,6 +167,16 @@ def add_composite_info(ops, backward_ops, backward_op_dict): else: op["backward_composite"] = None + # add whether only composite + if ( + op["backward_composite"] is not None + and "invoke" not in backward_op_dict[op["backward"]] + and "kernel" not in backward_op_dict[op["backward"]] + ): + op["only_backward_composite"] = True + else: + op["only_backward_composite"] = False + # add fluid name in ops and backward ops info def add_fluid_name(dict_list): @@ -248,6 +260,9 @@ def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): for param in op_item['invoke']['args'].split(',') ] return + elif 'composite' in op_item and 'kernel' not in op_item: + return + op_item['infer_meta']['param'] = get_param_list_alias( op_item['infer_meta']['param'], args_name_map ) diff --git a/paddle/fluid/operators/generator/generate_sparse_op.py 
b/paddle/fluid/operators/generator/generate_sparse_op.py index 2635f0f67a1..9c92aa3bc3c 100644 --- a/paddle/fluid/operators/generator/generate_sparse_op.py +++ b/paddle/fluid/operators/generator/generate_sparse_op.py @@ -40,6 +40,7 @@ from tests_utils import ( is_base_op, is_composite_op, is_initializer_list, + is_only_composite_op, is_scalar, is_vec, supports_inplace, @@ -71,6 +72,7 @@ env.filters["to_variable_names"] = to_variable_names env.filters["get_infer_var_type_func"] = get_infer_var_type_func env.tests["base_op"] = is_base_op env.tests["composite_op"] = is_composite_op +env.tests["only_composite_op"] = is_only_composite_op env.tests["vec"] = is_vec env.tests["scalar"] = is_scalar env.tests["initializer_list"] = is_initializer_list diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 8a8f4935e1e..d0ee533539d 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -498,8 +498,12 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): "data_transform": data_trans, } - # invokes another op ? - is_base_op = "invoke" not in op_entry + # op should be is_base_op or is_invoke_op or is_only_composite_op + is_base_op = True + if "invoke" in op_entry: + is_base_op = False + if "composite" in op_entry and "kernel" not in op_entry: + is_base_op = False if is_base_op: # kernel @@ -524,10 +528,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): "inplace": inplace_pairs, } ) - else: - # invoke - invoke = parse_invoke(op_name, op_entry["invoke"]) - op["invoke"] = invoke + + # has invoke ? + if "invoke" in op_entry: + invoke_dict = parse_invoke(op_name, op_entry["invoke"]) + op.update({"invoke": invoke_dict}) # has composite ? 
if "composite" in op_entry: diff --git a/paddle/fluid/operators/generator/templates/op.c.j2 b/paddle/fluid/operators/generator/templates/op.c.j2 index feeb1dee169..68673256a00 100644 --- a/paddle/fluid/operators/generator/templates/op.c.j2 +++ b/paddle/fluid/operators/generator/templates/op.c.j2 @@ -38,6 +38,8 @@ using paddle::framework::GradVarName; {{backward_op_maker(op, op_dict[op["forward"]["name"]])}} {{operator(op)}} + {% elif op is only_composite_op %} + {% else %} {{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}} {% endif %} diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 0b0190f4e4e..16f0ecaa642 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -472,7 +472,8 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, {% if not "forward" in op %}{# it is a forward op #} ops::{{name | to_pascal_case}}OpMaker, {% endif %} -{% if "backward" in op and op["backward"] is not none %}{# backward #} +{% if "only_backward_composite" in op and op["only_backward_composite"] is true %}{# backward #} +{% elif "backward" in op and op["backward"] is not none %} {% set backward_name = op["backward"] %} ops::{{backward_name | to_pascal_case}}OpMaker, ops::{{backward_name | to_pascal_case}}OpMaker, diff --git a/paddle/fluid/operators/generator/tests_utils.py b/paddle/fluid/operators/generator/tests_utils.py index 574f3663b7d..9cea1698f4f 100644 --- a/paddle/fluid/operators/generator/tests_utils.py +++ b/paddle/fluid/operators/generator/tests_utils.py @@ -54,6 +54,10 @@ def is_base_op(op): return "kernel" in op and "infer_meta" in op +def is_only_composite_op(op): + return "composite" in op and "kernel" not in op and "invoke" not in op + + def is_composite_op(op): return "composite" in op diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 6eb366dc355..c5d56dc82b5 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -21,6 +21,7 @@ #include #include "paddle/fluid/prim/api/all.h" +#include "paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/int_array.h" @@ -164,26 +165,6 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { set_output(grad_x_tmp, grad_x); } -template -void tanh_double_grad(const Tensor& out, - const Tensor& grad_out, - const Tensor& grad_x_grad, - Tensor* out_grad, - Tensor* grad_out_grad) { - // tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out * - // ddx) - auto out_m_grad_x_grad = out * grad_x_grad; - if (out_grad) { - auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad; - set_output(out_grad_tmp, out_grad); - } - - if (grad_out_grad) { - auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad; - set_output(grad_out_grad_tmp, grad_out_grad); - } -} - template void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) { if (grad_x) { @@ -698,315 +679,6 @@ void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { } } -template -void matmul_double_grad(const Tensor& x, - const Tensor& y, - const Tensor& 
grad_out, - const paddle::optional& grad_x_grad, - const paddle::optional& grad_y_grad, - bool transpose_x, - bool transpose_y, - Tensor* x_grad, - Tensor* y_grad, - Tensor* grad_out_grad) { - // Get dims from the input x, y, output_grad - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector grad_out_dims = vectorize(grad_out.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int dout_ndim = grad_out_dims.size(); - - // prepare dims for x_ndim <= 1 || y_ndim <= 1 - Tensor x_help, y_help, xg_help, yg_help, out_help; - - if (x_ndim == 1 && y_ndim == 1) { - transpose_x = false; - transpose_y = false; - x_help = reshape(x, IntArray(std::vector({1, x_dims[0]}))); - y_help = reshape(y, IntArray(std::vector({y_dims[0], 1}))); - if (grad_x_grad) { - xg_help = reshape(grad_x_grad.get(), - IntArray(std::vector({1, x_dims[0]}))); - } - if (grad_y_grad) { - yg_help = reshape(grad_y_grad.get(), - IntArray(std::vector({y_dims[0], 1}))); - } - out_help = reshape(grad_out, IntArray(std::vector({1, 1}))); - - } else if (x_ndim == 1) { - transpose_x = false; - x_help = reshape(x, IntArray(std::vector({1, x_dims[0]}))); - y_help = y; - if (grad_x_grad) { - xg_help = reshape(grad_x_grad.get(), - IntArray(std::vector({1, x_dims[0]}))); - } - if (grad_y_grad) { - yg_help = grad_y_grad.get(); - } - auto tmp_grad_out_dims = grad_out_dims; - tmp_grad_out_dims.insert(tmp_grad_out_dims.begin(), 1); - out_help = reshape(grad_out, IntArray(tmp_grad_out_dims)); - - } else if (y_ndim == 1) { - transpose_y = false; - x_help = x; - y_help = reshape(y, IntArray(std::vector({y_dims[0], 1}))); - if (grad_x_grad) { - xg_help = grad_x_grad.get(); - } - if (grad_y_grad) { - yg_help = reshape(grad_y_grad.get(), - IntArray(std::vector({y_dims[0], 1}))); - } - auto tmp_grad_out_dims = grad_out_dims; - tmp_grad_out_dims.push_back(1); - out_help = reshape(grad_out, IntArray(tmp_grad_out_dims)); - - } else { - x_help = x; - y_help = y; - if (grad_x_grad) { - xg_help = grad_x_grad.get(); - } - if (grad_y_grad) { - yg_help = grad_y_grad.get(); - } - out_help = grad_out; - } - - bool is_broadcast = true; - if (x_ndim <= 2 && y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - Tensor dx, dy, ddout_1, ddout_2, ddout; - if (!grad_x_grad && !grad_y_grad) { - x_grad = nullptr; - y_grad = nullptr; - grad_out_grad = nullptr; - return; - - } else if (!grad_x_grad) { - y_grad = nullptr; - if (!transpose_x && !transpose_y) { - if (x_grad) { - dx = matmul(out_help, yg_help, false, true); - } - if (grad_out_grad) { - ddout = matmul(x_help, yg_help, false, false); - } - } else if (!transpose_x && transpose_y) { - if (x_grad) { - dx = matmul(out_help, yg_help, false, false); - } - if (grad_out_grad) { - ddout = matmul(x_help, yg_help, false, true); - } - } else if (transpose_x && !transpose_y) { - if (x_grad) { - dx = matmul(yg_help, out_help, false, true); - } - if (grad_out_grad) { - ddout = matmul(x_help, yg_help, true, false); - } - } else { - if (x_grad) { - dx = matmul(yg_help, out_help, true, true); - } - if (grad_out_grad) { - ddout = matmul(x_help, yg_help, true, true); - } - } - - } else if (!grad_y_grad) { - x_grad = nullptr; - if (!transpose_x && !transpose_y) { - if (y_grad) { - dy = matmul(xg_help, out_help, true, false); - } - if (grad_out_grad) { - ddout = matmul(xg_help, y_help, false, false); - } - } else if 
(!transpose_x && transpose_y) { - if (y_grad) { - dy = matmul(out_help, xg_help, true, false); - } - if (grad_out_grad) { - ddout = matmul(xg_help, y_help, false, true); - } - } else if (transpose_x && !transpose_y) { - if (y_grad) { - dy = matmul(xg_help, out_help, false, false); - } - if (grad_out_grad) { - ddout = matmul(xg_help, y_help, true, false); - } - } else { - if (y_grad) { - dy = matmul(out_help, xg_help, true, true); - } - if (grad_out_grad) { - ddout = matmul(xg_help, y_help, true, true); - } - } - - } else { - if (!transpose_x && !transpose_y) { - if (x_grad) { - dx = matmul(out_help, yg_help, false, true); - } - if (y_grad) { - dy = matmul(xg_help, out_help, true, false); - } - if (grad_out_grad) { - ddout_1 = matmul(x_help, yg_help, false, false); - ddout_2 = matmul(xg_help, y_help, false, false); - ddout = add(ddout_1, ddout_2); - } - } else if (!transpose_x && transpose_y) { - if (x_grad) { - dx = matmul(out_help, yg_help, false, false); - } - - if (y_grad) { - dy = matmul(out_help, xg_help, true, false); - } - if (grad_out_grad) { - ddout_1 = matmul(x_help, yg_help, false, true); - ddout_2 = matmul(xg_help, y_help, false, true); - ddout = add(ddout_1, ddout_2); - } - } else if (transpose_x && !transpose_y) { - if (x_grad) { - dx = matmul(yg_help, out_help, false, true); - } - - if (y_grad) { - dy = matmul(xg_help, out_help, false, false); - } - if (grad_out_grad) { - ddout_1 = matmul(x_help, yg_help, true, false); - ddout_2 = matmul(xg_help, y_help, true, false); - ddout = add(ddout_1, ddout_2); - } - } else { - if (x_grad) { - dx = matmul(yg_help, out_help, true, true); - } - if (y_grad) { - dy = matmul(out_help, xg_help, true, true); - } - if (grad_out_grad) { - ddout_1 = matmul(x_help, yg_help, true, true); - ddout_2 = matmul(xg_help, y_help, true, true); - ddout = add(ddout_1, ddout_2); - } - } - } - - if (is_broadcast) { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - // Reduce sum to get grad by ReduceSum - if (x_grad) { - auto tx_dims = x_dims; - auto tx_ndim = x_ndim; - auto tdout_ndim = dout_ndim; - if (x_ndim == 1) { - tx_dims = std::vector({1, x_dims[0]}); - tx_ndim = x_ndim + 1; - tdout_ndim = dout_ndim + 1; - } - - auto x_grad_reduce_dims = - get_reduce_dims(dx, tdout_ndim, tx_ndim, &tx_dims); - - if (!x_grad_reduce_dims.empty()) { - dx = sum(dx, IntArray(x_grad_reduce_dims), dy.dtype(), true); - } - reshape(dx, IntArray(tx_dims)); - } - - if (y_grad) { - auto ty_dims = y_dims; - auto ty_ndim = y_ndim; - auto tdout_ndim = dout_ndim; - if (y_ndim == 1) { - ty_dims = std::vector({y_dims[0], 1}); - ty_ndim = y_ndim + 1; - tdout_ndim = dout_ndim + 1; - } - - auto y_grad_reduce_dims = - get_reduce_dims(dy, tdout_ndim, ty_ndim, &ty_dims); - - if (!y_grad_reduce_dims.empty()) { - dy = sum(dy, IntArray(y_grad_reduce_dims), dy.dtype(), true); - } - reshape(dy, IntArray(ty_dims)); - } - } - - // recover the original dim of output (delete 1) - std::vector dx_dims = - dx.initialized() ? vectorize(dx.dims()) : std::vector({}); - std::vector dy_dims = - dy.initialized() ? vectorize(dy.dims()) : std::vector({}); - std::vector ddout_dims = - ddout.initialized() ? 
vectorize(ddout.dims()) : std::vector({}); - if (x_ndim == 1 && y_ndim == 1) { - if (dx.initialized() && dx_dims[0] == 1) { - dx = reshape(dx, IntArray(x_dims)); - } - if (dy.initialized() && dy_dims.back() == 1) { - dy = reshape(dy, IntArray(y_dims)); - } - if (ddout.initialized() && ddout_dims == std::vector({1, 1})) { - ddout = reshape(ddout, IntArray(std::vector({1}))); - } - } else if (x_ndim == 1) { - if (dx.initialized() && dx_dims[0] == 1) { - dx = reshape(dx, IntArray(x_dims)); - } - if (ddout.initialized() && ddout_dims[0] == 1) { - ddout = reshape(ddout, - IntArray(std::vector( - {ddout_dims.cbegin() + 1, ddout_dims.cend()}))); - } - } else if (y_ndim == 1) { - if (dy.initialized() && dy_dims.back() == 1) { - dy = reshape(dy, IntArray(y_dims)); - } - if (ddout.initialized() && ddout_dims.back() == 1) { - ddout = reshape(ddout, - IntArray(std::vector( - {ddout_dims.cbegin(), - ddout_dims.cbegin() + ddout_dims.size() - 1}))); - } - } - - if (x_grad) { - set_output(dx, x_grad); - } - if (y_grad) { - set_output(dy, y_grad); - } - if (grad_out_grad) { - set_output(ddout, grad_out_grad); - } -} - template void slice_grad(const Tensor& input, const Tensor& out_grad, diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h new file mode 100644 index 00000000000..b97ffeb4b26 --- /dev/null +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -0,0 +1,387 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif + +#include + +#include "paddle/fluid/prim/api/all.h" +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/ddim.h" + +namespace paddle { +namespace prim { +using Tensor = paddle::Tensor; +using IntArray = paddle::experimental::IntArrayBase; +// This file define high level grad composite api for Higher order +// differentiation + +template +void tanh_double_grad(const Tensor& out, + const Tensor& grad_out, + const Tensor& grad_x_grad, + Tensor* out_grad, + Tensor* grad_out_grad) { + // tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out * + // ddx) + auto out_m_grad_x_grad = out * grad_x_grad; + if (out_grad) { + auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad; + set_output(out_grad_tmp, out_grad); + } + + if (grad_out_grad) { + auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad; + set_output(grad_out_grad_tmp, grad_out_grad); + } +} + +template +void matmul_double_grad(const Tensor& x, + const Tensor& y, + const Tensor& grad_out, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + bool transpose_x, + bool transpose_y, + Tensor* x_grad, + Tensor* y_grad, + Tensor* grad_out_grad) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector grad_out_dims = vectorize(grad_out.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int dout_ndim = grad_out_dims.size(); + + // prepare dims for x_ndim <= 1 || y_ndim <= 1 + Tensor x_help, y_help, xg_help, yg_help, out_help; + + if (x_ndim == 1 && y_ndim == 1) { + transpose_x = false; + transpose_y = false; + x_help = reshape(x, IntArray(std::vector({1, x_dims[0]}))); + y_help = reshape(y, IntArray(std::vector({y_dims[0], 1}))); + if (grad_x_grad) { + xg_help = reshape(grad_x_grad.get(), + IntArray(std::vector({1, x_dims[0]}))); + } + if (grad_y_grad) { + yg_help = reshape(grad_y_grad.get(), + IntArray(std::vector({y_dims[0], 1}))); + } + out_help = reshape(grad_out, IntArray(std::vector({1, 1}))); + + } else if (x_ndim == 1) { + transpose_x = false; + x_help = reshape(x, IntArray(std::vector({1, x_dims[0]}))); + y_help = y; + if (grad_x_grad) { + xg_help = reshape(grad_x_grad.get(), + IntArray(std::vector({1, x_dims[0]}))); + } + if (grad_y_grad) { + yg_help = grad_y_grad.get(); + } + auto tmp_grad_out_dims = grad_out_dims; + tmp_grad_out_dims.insert(tmp_grad_out_dims.begin(), 1); + out_help = reshape(grad_out, IntArray(tmp_grad_out_dims)); + + } else if (y_ndim == 1) { + transpose_y = false; + x_help = x; + y_help = reshape(y, IntArray(std::vector({y_dims[0], 1}))); + if (grad_x_grad) { + xg_help = grad_x_grad.get(); + } + if (grad_y_grad) { + yg_help = reshape(grad_y_grad.get(), + IntArray(std::vector({y_dims[0], 1}))); + } + auto tmp_grad_out_dims = grad_out_dims; + tmp_grad_out_dims.push_back(1); + out_help = reshape(grad_out, IntArray(tmp_grad_out_dims)); + + } else { + x_help = x; + y_help = y; + if (grad_x_grad) { + xg_help = grad_x_grad.get(); + } + if (grad_y_grad) { + yg_help = grad_y_grad.get(); + } + out_help = grad_out; + } + + bool is_broadcast = true; + if (x_ndim <= 2 && y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 
2, y_dims.cbegin()); + } + Tensor dx, dy, ddout_1, ddout_2, ddout; + if (!grad_x_grad && !grad_y_grad) { + x_grad = nullptr; + y_grad = nullptr; + grad_out_grad = nullptr; + return; + + } else if (!grad_x_grad) { + y_grad = nullptr; + if (!transpose_x && !transpose_y) { + if (x_grad) { + dx = matmul(out_help, yg_help, false, true); + } + if (grad_out_grad) { + ddout = matmul(x_help, yg_help, false, false); + } + } else if (!transpose_x && transpose_y) { + if (x_grad) { + dx = matmul(out_help, yg_help, false, false); + } + if (grad_out_grad) { + ddout = matmul(x_help, yg_help, false, true); + } + } else if (transpose_x && !transpose_y) { + if (x_grad) { + dx = matmul(yg_help, out_help, false, true); + } + if (grad_out_grad) { + ddout = matmul(x_help, yg_help, true, false); + } + } else { + if (x_grad) { + dx = matmul(yg_help, out_help, true, true); + } + if (grad_out_grad) { + ddout = matmul(x_help, yg_help, true, true); + } + } + + } else if (!grad_y_grad) { + x_grad = nullptr; + if (!transpose_x && !transpose_y) { + if (y_grad) { + dy = matmul(xg_help, out_help, true, false); + } + if (grad_out_grad) { + ddout = matmul(xg_help, y_help, false, false); + } + } else if (!transpose_x && transpose_y) { + if (y_grad) { + dy = matmul(out_help, xg_help, true, false); + } + if (grad_out_grad) { + ddout = matmul(xg_help, y_help, false, true); + } + } else if (transpose_x && !transpose_y) { + if (y_grad) { + dy = matmul(xg_help, out_help, false, false); + } + if (grad_out_grad) { + ddout = matmul(xg_help, y_help, true, false); + } + } else { + if (y_grad) { + dy = matmul(out_help, xg_help, true, true); + } + if (grad_out_grad) { + ddout = matmul(xg_help, y_help, true, true); + } + } + + } else { + if (!transpose_x && !transpose_y) { + if (x_grad) { + dx = matmul(out_help, yg_help, false, true); + } + if (y_grad) { + dy = matmul(xg_help, out_help, true, false); + } + if (grad_out_grad) { + ddout_1 = matmul(x_help, yg_help, false, false); + ddout_2 = matmul(xg_help, y_help, false, false); + ddout = add(ddout_1, ddout_2); + } + } else if (!transpose_x && transpose_y) { + if (x_grad) { + dx = matmul(out_help, yg_help, false, false); + } + + if (y_grad) { + dy = matmul(out_help, xg_help, true, false); + } + if (grad_out_grad) { + ddout_1 = matmul(x_help, yg_help, false, true); + ddout_2 = matmul(xg_help, y_help, false, true); + ddout = add(ddout_1, ddout_2); + } + } else if (transpose_x && !transpose_y) { + if (x_grad) { + dx = matmul(yg_help, out_help, false, true); + } + + if (y_grad) { + dy = matmul(xg_help, out_help, false, false); + } + if (grad_out_grad) { + ddout_1 = matmul(x_help, yg_help, true, false); + ddout_2 = matmul(xg_help, y_help, true, false); + ddout = add(ddout_1, ddout_2); + } + } else { + if (x_grad) { + dx = matmul(yg_help, out_help, true, true); + } + if (y_grad) { + dy = matmul(out_help, xg_help, true, true); + } + if (grad_out_grad) { + ddout_1 = matmul(x_help, yg_help, true, true); + ddout_2 = matmul(xg_help, y_help, true, true); + ddout = add(ddout_1, ddout_2); + } + } + } + + if (is_broadcast) { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. 
So we should avoid the case in reality"; + // Reduce sum to get grad by ReduceSum + if (x_grad) { + auto tx_dims = x_dims; + auto tx_ndim = x_ndim; + auto tdout_ndim = dout_ndim; + if (x_ndim == 1) { + tx_dims = std::vector({1, x_dims[0]}); + tx_ndim = x_ndim + 1; + tdout_ndim = dout_ndim + 1; + } + + auto x_grad_reduce_dims = + get_reduce_dims(dx, tdout_ndim, tx_ndim, &tx_dims); + + if (!x_grad_reduce_dims.empty()) { + dx = sum(dx, IntArray(x_grad_reduce_dims), dy.dtype(), true); + } + reshape(dx, IntArray(tx_dims)); + } + + if (y_grad) { + auto ty_dims = y_dims; + auto ty_ndim = y_ndim; + auto tdout_ndim = dout_ndim; + if (y_ndim == 1) { + ty_dims = std::vector({y_dims[0], 1}); + ty_ndim = y_ndim + 1; + tdout_ndim = dout_ndim + 1; + } + + auto y_grad_reduce_dims = + get_reduce_dims(dy, tdout_ndim, ty_ndim, &ty_dims); + + if (!y_grad_reduce_dims.empty()) { + dy = sum(dy, IntArray(y_grad_reduce_dims), dy.dtype(), true); + } + reshape(dy, IntArray(ty_dims)); + } + } + + // recover the original dim of output (delete 1) + std::vector dx_dims = + dx.initialized() ? vectorize(dx.dims()) : std::vector({}); + std::vector dy_dims = + dy.initialized() ? vectorize(dy.dims()) : std::vector({}); + std::vector ddout_dims = + ddout.initialized() ? vectorize(ddout.dims()) : std::vector({}); + if (x_ndim == 1 && y_ndim == 1) { + if (dx.initialized() && dx_dims[0] == 1) { + dx = reshape(dx, IntArray(x_dims)); + } + if (dy.initialized() && dy_dims.back() == 1) { + dy = reshape(dy, IntArray(y_dims)); + } + if (ddout.initialized() && ddout_dims == std::vector({1, 1})) { + ddout = reshape(ddout, IntArray(std::vector({1}))); + } + } else if (x_ndim == 1) { + if (dx.initialized() && dx_dims[0] == 1) { + dx = reshape(dx, IntArray(x_dims)); + } + if (ddout.initialized() && ddout_dims[0] == 1) { + ddout = reshape(ddout, + IntArray(std::vector( + {ddout_dims.cbegin() + 1, ddout_dims.cend()}))); + } + } else if (y_ndim == 1) { + if (dy.initialized() && dy_dims.back() == 1) { + dy = reshape(dy, IntArray(y_dims)); + } + if (ddout.initialized() && ddout_dims.back() == 1) { + ddout = reshape(ddout, + IntArray(std::vector( + {ddout_dims.cbegin(), + ddout_dims.cbegin() + ddout_dims.size() - 1}))); + } + } + + if (x_grad) { + set_output(dx, x_grad); + } + if (y_grad) { + set_output(dy, y_grad); + } + if (grad_out_grad) { + set_output(ddout, grad_out_grad); + } +} + +template +void silu_double_grad(const Tensor& x, + const Tensor& out, + const Tensor& out_grad, + const Tensor& grad_x_grad, + Tensor* grad_x, + Tensor* grad_out_grad) { + auto sigmoid = out / x; + auto tmp1 = 1 - sigmoid; + auto tmp2 = 1 + tmp1 * x; + if (grad_out_grad) { + auto ddout = grad_x_grad * sigmoid * tmp2; + set_output(ddout, grad_out_grad); + } + if (grad_x) { + auto dx = + sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - x * sigmoid)) * tmp1; + set_output(dx, grad_x); + } +} + +} // namespace prim +} // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index e3ef0460884..92ce63e4709 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1626,6 +1626,7 @@ param : [x] kernel : func : silu_grad + backward : silu_double_grad composite : silu_grad(x, out, out_grad, x_grad) inplace : (out_grad -> x_grad) @@ -2106,6 +2107,12 @@ func : yolo_loss_grad optional : gt_score +- backward_op: silu_double_grad + forward: silu_grad (Tensor x, Tensor out, Tensor grad_out) -> Tensor(grad_x) + args: (Tensor x, Tensor out, Tensor grad_out, Tensor grad_x_grad) + output: Tensor(x_grad), 
Tensor(grad_out_grad) + composite: silu_double_grad(x, out, grad_out, grad_x_grad, x_grad, grad_out_grad) + - backward_op: unpool3d_grad forward: unpool3d (Tensor x, Tensor indices, int[] ksize, int[] strides={1,1,1}, int[] paddings={0,0,0}, int[] output_size={0,0,0}, str data_format="NCDHW") -> Tensor(out) args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] paddings, int[] output_size, str data_format) diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 4d41bb522e0..e4ac5726b9c 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -41,6 +41,7 @@ class BaseAPI: ) = self.parse_args(self.api, api_item_yaml) self.is_base_api = True + self.is_only_composite_api = False if 'invoke' in api_item_yaml: self.is_base_api = False self.invoke = api_item_yaml['invoke'] @@ -49,7 +50,12 @@ class BaseAPI: self.infer_meta = self.parse_infer_meta( api_item_yaml['infer_meta'] ) - self.kernel = self.parse_kernel(api_item_yaml['kernel']) + if 'composite' in api_item_yaml and 'kernel' not in api_item_yaml: + self.is_base_api = False + self.is_only_composite_api = True + self.kernel = None + else: + self.kernel = self.parse_kernel(api_item_yaml['kernel']) self.data_transform = self.parse_data_transform(api_item_yaml) self.inplace_map, self.view_map = {}, {} @@ -1319,23 +1325,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ api_code = "" api_code = api_code + self.gene_base_api_code(inplace_flag=True) return api_code - + elif self.is_only_composite_api: + # for composite and invoke api, dygraph use prim::xxx_grad method + return '' else: - invoke_func_name = self.invoke.split('(')[0].strip() - if invoke_func_name in self.attrs['names']: - # Adjust the param whose name is same with api invoked. 
- pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]' - - def adjust_name(matched): - matched_str = matched.group() - return matched_str[0:-1] + '_val' + matched_str[-1] - - invoke_code = re.sub(pattern, adjust_name, self.invoke) - params_code = re.sub( - pattern, adjust_name, self.get_define_args() - ) - else: - invoke_code = self.invoke - params_code = self.get_define_args() - + invoke_code = self.invoke + params_code = self.get_define_args() return self.gene_invoke_code(invoke_code, params_code) diff --git a/paddle/phi/api/yaml/generator/backward_api_gen.py b/paddle/phi/api/yaml/generator/backward_api_gen.py index 36ac38a88dd..432f70f8dba 100644 --- a/paddle/phi/api/yaml/generator/backward_api_gen.py +++ b/paddle/phi/api/yaml/generator/backward_api_gen.py @@ -112,7 +112,7 @@ class BackwardAPI(BaseAPI): return "" def gene_api_declaration(self): - if not self.is_base_api: + if not self.is_base_api and not self.is_only_composite_api: invoke_func_name = self.invoke.split('(')[0] if (not invoke_func_name.endswith("_grad")) and ( not invoke_func_name.endswith('_impl') diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 74bdb3c53dd..5ad73f4e104 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2101,7 +2101,7 @@ out : Out - op : silu - backward : silu_grad + backward : silu_grad, silu_double_grad inputs : x : X outputs : diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py index 99268b1b58e..489ae6e6f1f 100644 --- a/test/prim/prim/vjp/test_comp_high_grad.py +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -331,5 +331,88 @@ class TestMultiplyHighGradCheck(unittest.TestCase): self.func_double(p) self.func_triple(p) ''' + + +@param.parameterized_class( + ('shape1'), + [ + ([2],), + ([2, 3],), + ([2, 3, 4],), + ([2, 3, 3, 4],), + ], +) +class TestSiluHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + + def silu_wrapper(self, x): + return paddle.nn.functional.silu(x[0]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.nn.functional.silu(x) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + + # silu double grad only has CompositeOpMaker,don't need set prim_flag + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.silu_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.nn.functional.silu(x) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.triple_grad_check_for_dygraph( + self.silu_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + def test_high_grad(self): + paddle.enable_static() + 
places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func_double(p) + self.func_triple(p) + + if __name__ == '__main__': unittest.main() -- GitLab
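
Reviewer note (appended after the patch, not part of it): the new silu_double_grad composite recovers sigmoid as out / x and then encodes the two closed forms
  silu'(x)  = sigmoid(x) * (1 + x * (1 - sigmoid(x)))                      -> grad_out_grad = grad_x_grad * silu'(x)
  silu''(x) = sigmoid(x) * (1 - sigmoid(x)) * (2 + x * (1 - 2 * sigmoid(x))) -> grad_x = out_grad * grad_x_grad * silu''(x)
since in the code tmp1 = 1 - sigmoid, tmp2 = 1 + x * tmp1, and (1 + (tmp2 - x * sigmoid)) * tmp1 = (2 + x * (1 - 2 * sigmoid)) * (1 - sigmoid). The sketch below is a minimal, standalone numpy check of those two closed forms against finite differences; the helper names (silu_d1, silu_d2) are illustrative and not part of Paddle's API.

# Standalone sanity check (assumption: plain numpy, no Paddle dependency).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

def silu_d1(x):
    # silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))); used for grad_out_grad
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def silu_d2(x):
    # silu''(x) = sigmoid(x) * (1 - sigmoid(x)) * (2 + x * (1 - 2 * sigmoid(x))); used for grad_x
    s = sigmoid(x)
    return s * (1.0 - s) * (2.0 + x * (1.0 - 2.0 * s))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.uniform(-3.0, 3.0, size=1000)
    eps = 1e-4
    # central finite differences for the first and second derivatives
    d1_fd = (silu(x + eps) - silu(x - eps)) / (2.0 * eps)
    d2_fd = (silu(x + eps) - 2.0 * silu(x) + silu(x - eps)) / (eps * eps)
    assert np.allclose(silu_d1(x), d1_fd, atol=1e-6)
    assert np.allclose(silu_d2(x), d2_fd, atol=1e-4)
    print("silu first/second derivative formulas match finite differences")

This mirrors what gradient_checker.double_grad_check exercises in test_comp_high_grad.py, but isolates the algebra so the composite rule can be audited without building Paddle.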