Unverified commit 94c38803, authored by xiaoguoguo626807, committed by GitHub

Silu double grad (#53605)

* add rules

* modify no-kernel YAML parsing

* successfully generate op

* successfully pass test_silu_double

* fix bug

* fix static error

* modify silu_grad input

* modify kernel signature

* modify kernel signature

* code style

* code style

* review

* remove OpInfo modification
Parent 0ef51804
@@ -68,6 +68,7 @@ prim_white_list = [
    "matmul_double_grad",
    "tanh_double_grad",
    "subtract_double_grad",
    "silu_double_grad",
]
# dict of special api that forward api's output will affect bacward api's output
...
@@ -41,6 +41,7 @@ from tests_utils import (
    is_base_op,
    is_composite_op,
    is_initializer_list,
    is_only_composite_op,
    is_scalar,
    is_vec,
    supports_inplace,
@@ -72,6 +73,7 @@ env.filters["assert_dense_or_sr"] = assert_dense_or_sr
env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["only_composite_op"] = is_only_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
@@ -165,6 +167,16 @@ def add_composite_info(ops, backward_ops, backward_op_dict):
        else:
            op["backward_composite"] = None
        # add whether only composite
        if (
            op["backward_composite"] is not None
            and "invoke" not in backward_op_dict[op["backward"]]
            and "kernel" not in backward_op_dict[op["backward"]]
        ):
            op["only_backward_composite"] = True
        else:
            op["only_backward_composite"] = False


# add fluid name in ops and backward ops info
def add_fluid_name(dict_list):
@@ -248,6 +260,9 @@ def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
                for param in op_item['invoke']['args'].split(',')
            ]
            return
        elif 'composite' in op_item and 'kernel' not in op_item:
            return
        op_item['infer_meta']['param'] = get_param_list_alias(
            op_item['infer_meta']['param'], args_name_map
        )
...
@@ -40,6 +40,7 @@ from tests_utils import (
    is_base_op,
    is_composite_op,
    is_initializer_list,
    is_only_composite_op,
    is_scalar,
    is_vec,
    supports_inplace,
@@ -71,6 +72,7 @@ env.filters["to_variable_names"] = to_variable_names
env.filters["get_infer_var_type_func"] = get_infer_var_type_func
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["only_composite_op"] = is_only_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
...
@@ -498,8 +498,12 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
        "data_transform": data_trans,
    }

    # invokes another op ?
    is_base_op = "invoke" not in op_entry
    # op should be is_base_op or is_invoke_op or is_only_composite_op
    is_base_op = True
    if "invoke" in op_entry:
        is_base_op = False
    if "composite" in op_entry and "kernel" not in op_entry:
        is_base_op = False

    if is_base_op:
        # kernel
@@ -524,10 +528,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
                "inplace": inplace_pairs,
            }
        )
    else:
        # invoke
        invoke = parse_invoke(op_name, op_entry["invoke"])
        op["invoke"] = invoke

    # has invoke ?
    if "invoke" in op_entry:
        invoke_dict = parse_invoke(op_name, op_entry["invoke"])
        op.update({"invoke": invoke_dict})

    # has composite ?
    if "composite" in op_entry:
...
@@ -38,6 +38,8 @@ using paddle::framework::GradVarName;
{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(op)}}
{% elif op is only_composite_op %}
{% else %}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %}
...
@@ -472,7 +472,8 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in op %}{# it is a forward op #}
  ops::{{name | to_pascal_case}}OpMaker,
{% endif %}
{% if "backward" in op and op["backward"] is not none %}{# backward #}
{% if "only_backward_composite" in op and op["only_backward_composite"] is true %}{# backward #}
{% elif "backward" in op and op["backward"] is not none %}
{% set backward_name = op["backward"] %}
  ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
  ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
...
@@ -54,6 +54,10 @@ def is_base_op(op):
    return "kernel" in op and "infer_meta" in op


def is_only_composite_op(op):
    return "composite" in op and "kernel" not in op and "invoke" not in op


def is_composite_op(op):
    return "composite" in op
...
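For reference, a minimal sketch of what the new predicate distinguishes; the example dicts below are hypothetical stand-ins for parsed YAML entries, not taken from this patch. An op that carries only a composite rule (no kernel, no invoke) is expanded through the prim composite rule instead of getting a generated kernel or OpMaker.

def is_only_composite_op(op):
    return "composite" in op and "kernel" not in op and "invoke" not in op

# Hypothetical parsed entries mirroring the YAML fields touched by this patch.
silu_double_grad_entry = {
    "composite": "silu_double_grad(x, out, grad_out, grad_x_grad, x_grad, grad_out_grad)",
}
silu_grad_entry = {
    "composite": "silu_grad(x, out, out_grad, x_grad)",
    "kernel": {"func": "silu_grad"},
}

print(is_only_composite_op(silu_double_grad_entry))  # True: codegen emits no kernel or OpMaker for it
print(is_only_composite_op(silu_grad_entry))         # False: regular kernel-backed backward op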
@@ -21,6 +21,7 @@
#include <math.h>
#include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/int_array.h"
@@ -164,26 +165,6 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
  set_output<T>(grad_x_tmp, grad_x);
}
template <typename T>
void tanh_double_grad(const Tensor& out,
const Tensor& grad_out,
const Tensor& grad_x_grad,
Tensor* out_grad,
Tensor* grad_out_grad) {
// tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out *
// ddx)
auto out_m_grad_x_grad = out * grad_x_grad;
if (out_grad) {
auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad;
set_output<T>(out_grad_tmp, out_grad);
}
if (grad_out_grad) {
auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad;
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
}
template <typename T>
void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
  if (grad_x) {
@@ -698,315 +679,6 @@ void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  }
}
template <typename T>
void matmul_double_grad(const Tensor& x,
const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
bool transpose_x,
bool transpose_y,
Tensor* x_grad,
Tensor* y_grad,
Tensor* grad_out_grad) {
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> grad_out_dims = vectorize(grad_out.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int dout_ndim = grad_out_dims.size();
// prepare dims for x_ndim <= 1 || y_ndim <= 1
Tensor x_help, y_help, xg_help, yg_help, out_help;
if (x_ndim == 1 && y_ndim == 1) {
transpose_x = false;
transpose_y = false;
x_help = reshape<T>(x, IntArray(std::vector<int64_t>({1, x_dims[0]})));
y_help = reshape<T>(y, IntArray(std::vector<int64_t>({y_dims[0], 1})));
if (grad_x_grad) {
xg_help = reshape<T>(grad_x_grad.get(),
IntArray(std::vector<int64_t>({1, x_dims[0]})));
}
if (grad_y_grad) {
yg_help = reshape<T>(grad_y_grad.get(),
IntArray(std::vector<int64_t>({y_dims[0], 1})));
}
out_help = reshape<T>(grad_out, IntArray(std::vector<int64_t>({1, 1})));
} else if (x_ndim == 1) {
transpose_x = false;
x_help = reshape<T>(x, IntArray(std::vector<int64_t>({1, x_dims[0]})));
y_help = y;
if (grad_x_grad) {
xg_help = reshape<T>(grad_x_grad.get(),
IntArray(std::vector<int64_t>({1, x_dims[0]})));
}
if (grad_y_grad) {
yg_help = grad_y_grad.get();
}
auto tmp_grad_out_dims = grad_out_dims;
tmp_grad_out_dims.insert(tmp_grad_out_dims.begin(), 1);
out_help = reshape<T>(grad_out, IntArray(tmp_grad_out_dims));
} else if (y_ndim == 1) {
transpose_y = false;
x_help = x;
y_help = reshape<T>(y, IntArray(std::vector<int64_t>({y_dims[0], 1})));
if (grad_x_grad) {
xg_help = grad_x_grad.get();
}
if (grad_y_grad) {
yg_help = reshape<T>(grad_y_grad.get(),
IntArray(std::vector<int64_t>({y_dims[0], 1})));
}
auto tmp_grad_out_dims = grad_out_dims;
tmp_grad_out_dims.push_back(1);
out_help = reshape<T>(grad_out, IntArray(tmp_grad_out_dims));
} else {
x_help = x;
y_help = y;
if (grad_x_grad) {
xg_help = grad_x_grad.get();
}
if (grad_y_grad) {
yg_help = grad_y_grad.get();
}
out_help = grad_out;
}
bool is_broadcast = true;
if (x_ndim <= 2 && y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
Tensor dx, dy, ddout_1, ddout_2, ddout;
if (!grad_x_grad && !grad_y_grad) {
x_grad = nullptr;
y_grad = nullptr;
grad_out_grad = nullptr;
return;
} else if (!grad_x_grad) {
y_grad = nullptr;
if (!transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, false, false);
}
} else if (!transpose_x && transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, false);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, false, true);
}
} else if (transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, false, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, true, false);
}
} else {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, true, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, true, true);
}
}
} else if (!grad_y_grad) {
x_grad = nullptr;
if (!transpose_x && !transpose_y) {
if (y_grad) {
dy = matmul<T>(xg_help, out_help, true, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, false, false);
}
} else if (!transpose_x && transpose_y) {
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, false, true);
}
} else if (transpose_x && !transpose_y) {
if (y_grad) {
dy = matmul<T>(xg_help, out_help, false, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, true, false);
}
} else {
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, true);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, true, true);
}
}
} else {
if (!transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, true);
}
if (y_grad) {
dy = matmul<T>(xg_help, out_help, true, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, false, false);
ddout_2 = matmul<T>(xg_help, y_help, false, false);
ddout = add<T>(ddout_1, ddout_2);
}
} else if (!transpose_x && transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, false);
}
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, false, true);
ddout_2 = matmul<T>(xg_help, y_help, false, true);
ddout = add<T>(ddout_1, ddout_2);
}
} else if (transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, false, true);
}
if (y_grad) {
dy = matmul<T>(xg_help, out_help, false, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, true, false);
ddout_2 = matmul<T>(xg_help, y_help, true, false);
ddout = add<T>(ddout_1, ddout_2);
}
} else {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, true, true);
}
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, true);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, true, true);
ddout_2 = matmul<T>(xg_help, y_help, true, true);
ddout = add<T>(ddout_1, ddout_2);
}
}
}
if (is_broadcast) {
// Case3: broadcast. It need cost much time to reduce sum for the
// broadcast and wastes the memory.
// So we should avoid the case in reality.
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
// Reduce sum to get grad by ReduceSum
if (x_grad) {
auto tx_dims = x_dims;
auto tx_ndim = x_ndim;
auto tdout_ndim = dout_ndim;
if (x_ndim == 1) {
tx_dims = std::vector<int64_t>({1, x_dims[0]});
tx_ndim = x_ndim + 1;
tdout_ndim = dout_ndim + 1;
}
auto x_grad_reduce_dims =
get_reduce_dims(dx, tdout_ndim, tx_ndim, &tx_dims);
if (!x_grad_reduce_dims.empty()) {
dx = sum<T>(dx, IntArray(x_grad_reduce_dims), dy.dtype(), true);
}
reshape<T>(dx, IntArray(tx_dims));
}
if (y_grad) {
auto ty_dims = y_dims;
auto ty_ndim = y_ndim;
auto tdout_ndim = dout_ndim;
if (y_ndim == 1) {
ty_dims = std::vector<int64_t>({y_dims[0], 1});
ty_ndim = y_ndim + 1;
tdout_ndim = dout_ndim + 1;
}
auto y_grad_reduce_dims =
get_reduce_dims(dy, tdout_ndim, ty_ndim, &ty_dims);
if (!y_grad_reduce_dims.empty()) {
dy = sum<T>(dy, IntArray(y_grad_reduce_dims), dy.dtype(), true);
}
reshape<T>(dy, IntArray(ty_dims));
}
}
// recover the original dim of output (delete 1)
std::vector<int64_t> dx_dims =
dx.initialized() ? vectorize(dx.dims()) : std::vector<int64_t>({});
std::vector<int64_t> dy_dims =
dy.initialized() ? vectorize(dy.dims()) : std::vector<int64_t>({});
std::vector<int64_t> ddout_dims =
ddout.initialized() ? vectorize(ddout.dims()) : std::vector<int64_t>({});
if (x_ndim == 1 && y_ndim == 1) {
if (dx.initialized() && dx_dims[0] == 1) {
dx = reshape<T>(dx, IntArray(x_dims));
}
if (dy.initialized() && dy_dims.back() == 1) {
dy = reshape<T>(dy, IntArray(y_dims));
}
if (ddout.initialized() && ddout_dims == std::vector<int64_t>({1, 1})) {
ddout = reshape<T>(ddout, IntArray(std::vector<int64_t>({1})));
}
} else if (x_ndim == 1) {
if (dx.initialized() && dx_dims[0] == 1) {
dx = reshape<T>(dx, IntArray(x_dims));
}
if (ddout.initialized() && ddout_dims[0] == 1) {
ddout = reshape<T>(ddout,
IntArray(std::vector<int64_t>(
{ddout_dims.cbegin() + 1, ddout_dims.cend()})));
}
} else if (y_ndim == 1) {
if (dy.initialized() && dy_dims.back() == 1) {
dy = reshape<T>(dy, IntArray(y_dims));
}
if (ddout.initialized() && ddout_dims.back() == 1) {
ddout = reshape<T>(ddout,
IntArray(std::vector<int64_t>(
{ddout_dims.cbegin(),
ddout_dims.cbegin() + ddout_dims.size() - 1})));
}
}
if (x_grad) {
set_output<T>(dx, x_grad);
}
if (y_grad) {
set_output<T>(dy, y_grad);
}
if (grad_out_grad) {
set_output<T>(ddout, grad_out_grad);
}
}
template <typename T>
void slice_grad(const Tensor& input,
                const Tensor& out_grad,
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace prim {
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
// This file define high level grad composite api for Higher order
// differentiation
template <typename T>
void tanh_double_grad(const Tensor& out,
const Tensor& grad_out,
const Tensor& grad_x_grad,
Tensor* out_grad,
Tensor* grad_out_grad) {
// tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out *
// ddx)
auto out_m_grad_x_grad = out * grad_x_grad;
if (out_grad) {
auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad;
set_output<T>(out_grad_tmp, out_grad);
}
if (grad_out_grad) {
auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad;
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
}
template <typename T>
void matmul_double_grad(const Tensor& x,
const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
bool transpose_x,
bool transpose_y,
Tensor* x_grad,
Tensor* y_grad,
Tensor* grad_out_grad) {
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> grad_out_dims = vectorize(grad_out.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int dout_ndim = grad_out_dims.size();
// prepare dims for x_ndim <= 1 || y_ndim <= 1
Tensor x_help, y_help, xg_help, yg_help, out_help;
if (x_ndim == 1 && y_ndim == 1) {
transpose_x = false;
transpose_y = false;
x_help = reshape<T>(x, IntArray(std::vector<int64_t>({1, x_dims[0]})));
y_help = reshape<T>(y, IntArray(std::vector<int64_t>({y_dims[0], 1})));
if (grad_x_grad) {
xg_help = reshape<T>(grad_x_grad.get(),
IntArray(std::vector<int64_t>({1, x_dims[0]})));
}
if (grad_y_grad) {
yg_help = reshape<T>(grad_y_grad.get(),
IntArray(std::vector<int64_t>({y_dims[0], 1})));
}
out_help = reshape<T>(grad_out, IntArray(std::vector<int64_t>({1, 1})));
} else if (x_ndim == 1) {
transpose_x = false;
x_help = reshape<T>(x, IntArray(std::vector<int64_t>({1, x_dims[0]})));
y_help = y;
if (grad_x_grad) {
xg_help = reshape<T>(grad_x_grad.get(),
IntArray(std::vector<int64_t>({1, x_dims[0]})));
}
if (grad_y_grad) {
yg_help = grad_y_grad.get();
}
auto tmp_grad_out_dims = grad_out_dims;
tmp_grad_out_dims.insert(tmp_grad_out_dims.begin(), 1);
out_help = reshape<T>(grad_out, IntArray(tmp_grad_out_dims));
} else if (y_ndim == 1) {
transpose_y = false;
x_help = x;
y_help = reshape<T>(y, IntArray(std::vector<int64_t>({y_dims[0], 1})));
if (grad_x_grad) {
xg_help = grad_x_grad.get();
}
if (grad_y_grad) {
yg_help = reshape<T>(grad_y_grad.get(),
IntArray(std::vector<int64_t>({y_dims[0], 1})));
}
auto tmp_grad_out_dims = grad_out_dims;
tmp_grad_out_dims.push_back(1);
out_help = reshape<T>(grad_out, IntArray(tmp_grad_out_dims));
} else {
x_help = x;
y_help = y;
if (grad_x_grad) {
xg_help = grad_x_grad.get();
}
if (grad_y_grad) {
yg_help = grad_y_grad.get();
}
out_help = grad_out;
}
bool is_broadcast = true;
if (x_ndim <= 2 && y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
Tensor dx, dy, ddout_1, ddout_2, ddout;
if (!grad_x_grad && !grad_y_grad) {
x_grad = nullptr;
y_grad = nullptr;
grad_out_grad = nullptr;
return;
} else if (!grad_x_grad) {
y_grad = nullptr;
if (!transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, false, false);
}
} else if (!transpose_x && transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, false);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, false, true);
}
} else if (transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, false, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, true, false);
}
} else {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, true, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, true, true);
}
}
} else if (!grad_y_grad) {
x_grad = nullptr;
if (!transpose_x && !transpose_y) {
if (y_grad) {
dy = matmul<T>(xg_help, out_help, true, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, false, false);
}
} else if (!transpose_x && transpose_y) {
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, false, true);
}
} else if (transpose_x && !transpose_y) {
if (y_grad) {
dy = matmul<T>(xg_help, out_help, false, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, true, false);
}
} else {
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, true);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, true, true);
}
}
} else {
if (!transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, true);
}
if (y_grad) {
dy = matmul<T>(xg_help, out_help, true, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, false, false);
ddout_2 = matmul<T>(xg_help, y_help, false, false);
ddout = add<T>(ddout_1, ddout_2);
}
} else if (!transpose_x && transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, false);
}
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, false, true);
ddout_2 = matmul<T>(xg_help, y_help, false, true);
ddout = add<T>(ddout_1, ddout_2);
}
} else if (transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, false, true);
}
if (y_grad) {
dy = matmul<T>(xg_help, out_help, false, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, true, false);
ddout_2 = matmul<T>(xg_help, y_help, true, false);
ddout = add<T>(ddout_1, ddout_2);
}
} else {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, true, true);
}
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, true);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, true, true);
ddout_2 = matmul<T>(xg_help, y_help, true, true);
ddout = add<T>(ddout_1, ddout_2);
}
}
}
if (is_broadcast) {
// Case3: broadcast. It need cost much time to reduce sum for the
// broadcast and wastes the memory.
// So we should avoid the case in reality.
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
// Reduce sum to get grad by ReduceSum
if (x_grad) {
auto tx_dims = x_dims;
auto tx_ndim = x_ndim;
auto tdout_ndim = dout_ndim;
if (x_ndim == 1) {
tx_dims = std::vector<int64_t>({1, x_dims[0]});
tx_ndim = x_ndim + 1;
tdout_ndim = dout_ndim + 1;
}
auto x_grad_reduce_dims =
get_reduce_dims(dx, tdout_ndim, tx_ndim, &tx_dims);
if (!x_grad_reduce_dims.empty()) {
dx = sum<T>(dx, IntArray(x_grad_reduce_dims), dy.dtype(), true);
}
reshape<T>(dx, IntArray(tx_dims));
}
if (y_grad) {
auto ty_dims = y_dims;
auto ty_ndim = y_ndim;
auto tdout_ndim = dout_ndim;
if (y_ndim == 1) {
ty_dims = std::vector<int64_t>({y_dims[0], 1});
ty_ndim = y_ndim + 1;
tdout_ndim = dout_ndim + 1;
}
auto y_grad_reduce_dims =
get_reduce_dims(dy, tdout_ndim, ty_ndim, &ty_dims);
if (!y_grad_reduce_dims.empty()) {
dy = sum<T>(dy, IntArray(y_grad_reduce_dims), dy.dtype(), true);
}
reshape<T>(dy, IntArray(ty_dims));
}
}
// recover the original dim of output (delete 1)
std::vector<int64_t> dx_dims =
dx.initialized() ? vectorize(dx.dims()) : std::vector<int64_t>({});
std::vector<int64_t> dy_dims =
dy.initialized() ? vectorize(dy.dims()) : std::vector<int64_t>({});
std::vector<int64_t> ddout_dims =
ddout.initialized() ? vectorize(ddout.dims()) : std::vector<int64_t>({});
if (x_ndim == 1 && y_ndim == 1) {
if (dx.initialized() && dx_dims[0] == 1) {
dx = reshape<T>(dx, IntArray(x_dims));
}
if (dy.initialized() && dy_dims.back() == 1) {
dy = reshape<T>(dy, IntArray(y_dims));
}
if (ddout.initialized() && ddout_dims == std::vector<int64_t>({1, 1})) {
ddout = reshape<T>(ddout, IntArray(std::vector<int64_t>({1})));
}
} else if (x_ndim == 1) {
if (dx.initialized() && dx_dims[0] == 1) {
dx = reshape<T>(dx, IntArray(x_dims));
}
if (ddout.initialized() && ddout_dims[0] == 1) {
ddout = reshape<T>(ddout,
IntArray(std::vector<int64_t>(
{ddout_dims.cbegin() + 1, ddout_dims.cend()})));
}
} else if (y_ndim == 1) {
if (dy.initialized() && dy_dims.back() == 1) {
dy = reshape<T>(dy, IntArray(y_dims));
}
if (ddout.initialized() && ddout_dims.back() == 1) {
ddout = reshape<T>(ddout,
IntArray(std::vector<int64_t>(
{ddout_dims.cbegin(),
ddout_dims.cbegin() + ddout_dims.size() - 1})));
}
}
if (x_grad) {
set_output<T>(dx, x_grad);
}
if (y_grad) {
set_output<T>(dy, y_grad);
}
if (grad_out_grad) {
set_output<T>(ddout, grad_out_grad);
}
}
template <typename T>
void silu_double_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const Tensor& grad_x_grad,
Tensor* grad_x,
Tensor* grad_out_grad) {
auto sigmoid = out / x;
auto tmp1 = 1 - sigmoid;
auto tmp2 = 1 + tmp1 * x;
if (grad_out_grad) {
auto ddout = grad_x_grad * sigmoid * tmp2;
set_output<T>(ddout, grad_out_grad);
}
if (grad_x) {
auto dx =
sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - x * sigmoid)) * tmp1;
set_output<T>(dx, grad_x);
}
}
} // namespace prim
} // namespace paddle
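As a sanity check on the algebra in silu_double_grad above: with sigma = sigmoid(x) and out = x * sigma, the rule computes ddout = ddx * sigma * (1 + x * (1 - sigma)) and dx = dout * ddx * sigma * (1 - sigma) * (2 + x * (1 - 2 * sigma)), i.e. f'(x) * ddx and f''(x) * dout * ddx for f(x) = x * sigmoid(x). A standalone NumPy check against finite differences (an illustrative sketch, not part of the patch):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def silu_double_grad_ref(x, dout, ddx):
    # Mirrors the composite rule above; sigmoid is written out directly
    # instead of being recovered as out / x.
    sig = 1.0 / (1.0 + np.exp(-x))
    ddout = ddx * sig * (1.0 + x * (1.0 - sig))                          # f'(x) * ddx
    dx = dout * ddx * sig * (1.0 - sig) * (2.0 + x * (1.0 - 2.0 * sig))  # f''(x) * dout * ddx
    return dx, ddout

x, dout, ddx, eps = 0.7, 1.3, 0.9, 1e-4
f1 = (silu(x + eps) - silu(x - eps)) / (2 * eps)             # first derivative, central difference
f2 = (silu(x + eps) - 2 * silu(x) + silu(x - eps)) / eps**2  # second derivative, central difference
dx, ddout = silu_double_grad_ref(x, dout, ddx)
assert abs(ddout - f1 * ddx) < 1e-6
assert abs(dx - dout * ddx * f2) < 1e-6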
@@ -1626,6 +1626,7 @@
    param : [x]
  kernel :
    func : silu_grad
  backward : silu_double_grad
  composite : silu_grad(x, out, out_grad, x_grad)
  inplace : (out_grad -> x_grad)
@@ -2106,6 +2107,12 @@
    func : yolo_loss_grad
  optional : gt_score

- backward_op: silu_double_grad
  forward: silu_grad (Tensor x, Tensor out, Tensor grad_out) -> Tensor(grad_x)
  args: (Tensor x, Tensor out, Tensor grad_out, Tensor grad_x_grad)
  output: Tensor(x_grad), Tensor(grad_out_grad)
  composite: silu_double_grad(x, out, grad_out, grad_x_grad, x_grad, grad_out_grad)

- backward_op: unpool3d_grad
  forward: unpool3d (Tensor x, Tensor indices, int[] ksize, int[] strides={1,1,1}, int[] paddings={0,0,0}, int[] output_size={0,0,0}, str data_format="NCDHW") -> Tensor(out)
  args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] paddings, int[] output_size, str data_format)
...
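With the silu_double_grad entry registered above, second-order gradients of silu are computed through the composite rule; a minimal dygraph usage sketch (assumes a Paddle build that includes this change):

import paddle

x = paddle.to_tensor([0.5, -1.0, 2.0], stop_gradient=False)
y = paddle.nn.functional.silu(x)

# First-order gradient, kept on the graph so it can be differentiated again.
(dx,) = paddle.grad(y, x, create_graph=True)
# Second-order gradient of sum(dx) w.r.t. x, i.e. elementwise silu''(x).
(ddx,) = paddle.grad(dx.sum(), x)
print(ddx)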
@@ -41,6 +41,7 @@ class BaseAPI:
        ) = self.parse_args(self.api, api_item_yaml)
        self.is_base_api = True
        self.is_only_composite_api = False
        if 'invoke' in api_item_yaml:
            self.is_base_api = False
            self.invoke = api_item_yaml['invoke']
@@ -49,7 +50,12 @@ class BaseAPI:
            self.infer_meta = self.parse_infer_meta(
                api_item_yaml['infer_meta']
            )
            self.kernel = self.parse_kernel(api_item_yaml['kernel'])
            if 'composite' in api_item_yaml and 'kernel' not in api_item_yaml:
                self.is_base_api = False
                self.is_only_composite_api = True
                self.kernel = None
            else:
                self.kernel = self.parse_kernel(api_item_yaml['kernel'])
            self.data_transform = self.parse_data_transform(api_item_yaml)
            self.inplace_map, self.view_map = {}, {}
@@ -1319,23 +1325,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
            api_code = ""
            api_code = api_code + self.gene_base_api_code(inplace_flag=True)
            return api_code
        elif self.is_only_composite_api:
            # for composite and invoke api, dygraph use prim::xxx_grad method
            return ''
        else:
            invoke_func_name = self.invoke.split('(')[0].strip()
            if invoke_func_name in self.attrs['names']:
                # Adjust the param whose name is same with api invoked.
                pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'

                def adjust_name(matched):
                    matched_str = matched.group()
                    return matched_str[0:-1] + '_val' + matched_str[-1]

                invoke_code = re.sub(pattern, adjust_name, self.invoke)
                params_code = re.sub(
                    pattern, adjust_name, self.get_define_args()
                )
            else:
                invoke_code = self.invoke
                params_code = self.get_define_args()
            invoke_code = self.invoke
            params_code = self.get_define_args()
            return self.gene_invoke_code(invoke_code, params_code)
@@ -112,7 +112,7 @@ class BackwardAPI(BaseAPI):
            return ""

    def gene_api_declaration(self):
        if not self.is_base_api:
        if not self.is_base_api and not self.is_only_composite_api:
            invoke_func_name = self.invoke.split('(')[0]
            if (not invoke_func_name.endswith("_grad")) and (
                not invoke_func_name.endswith('_impl')
...
@@ -2101,7 +2101,7 @@
    out : Out

- op : silu
  backward : silu_grad
  backward : silu_grad, silu_double_grad
  inputs :
    x : X
  outputs :
...
@@ -331,5 +331,88 @@ class TestMultiplyHighGradCheck(unittest.TestCase):
                self.func_double(p)
                self.func_triple(p)
'''


@param.parameterized_class(
    ('shape1'),
    [
        ([2],),
        ([2, 3],),
        ([2, 3, 4],),
        ([2, 3, 3, 4],),
    ],
)
class TestSiluHighGradCheck(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.shape1 = cls.shape1

    def silu_wrapper(self, x):
        return paddle.nn.functional.silu(x[0])

    @prog_scope()
    def func_double(self, place):
        shape1 = self.shape1
        eps = 0.0005
        dtype = np.float64
        x = paddle.static.data('x', shape1, dtype=dtype)
        x.stop_gradient = False
        x.persistable = True
        out = paddle.nn.functional.silu(x)
        x_arr = np.random.uniform(-1, 1, shape1).astype(dtype)
        x_arr[np.abs(x_arr) < 0.005] = 0.002
        # silu double grad only has CompositeOpMaker,don't need set prim_flag
        from paddle.fluid import core

        core._set_prim_backward_enabled(True)
        gradient_checker.double_grad_check(
            [x], y=out, x_init=[x_arr], place=place, eps=eps
        )
        gradient_checker.double_grad_check_for_dygraph(
            self.silu_wrapper,
            [x],
            y=out,
            x_init=[x_arr],
            place=place,
        )
        core._set_prim_backward_enabled(False)

    @prog_scope()
    def func_triple(self, place):
        shape1 = self.shape1
        eps = 0.0005
        dtype = np.float64
        x = paddle.static.data('x', shape1, dtype=dtype)
        x.stop_gradient = False
        x.persistable = True
        out = paddle.nn.functional.silu(x)
        x_arr = np.random.uniform(-1, 1, shape1).astype(dtype)
        x_arr[np.abs(x_arr) < 0.005] = 0.002
        from paddle.fluid import core

        core._set_prim_backward_enabled(True)
        gradient_checker.triple_grad_check(
            [x], y=out, x_init=[x_arr], place=place, eps=eps
        )
        gradient_checker.triple_grad_check_for_dygraph(
            self.silu_wrapper,
            [x],
            y=out,
            x_init=[x_arr],
            place=place,
        )
        core._set_prim_backward_enabled(False)

    def test_high_grad(self):
        paddle.enable_static()
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func_double(p)
            self.func_triple(p)


if __name__ == '__main__':
    unittest.main()