Commit 127d2664 authored by W wangruting

merge develop

Parent 184fa04c
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/phi/common/int_array.h"
@@ -23,16 +24,17 @@ namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
// using IntArray = paddle::experimental::IntArray;
// This function should have the same signature as phi, which is defined in
// paddle/phi/api/backward/backward_api.h
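// tanh_grad: dx = dout * (1 - out^2), built here from pow, scale and multiply.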
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp);
set_output<T>(grad_x_tmp, grad_x);
set_output<T>(grad_x_tmp.impl(), grad_x);
}
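// subtract_grad: dx = dout and dy = -dout; when x and y were broadcast, the
// gradient of the smaller operand is summed over the broadcast axes and
// reshaped back to that operand's shape.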
template <typename T>
void subtract_grad(const Tensor& x,
const Tensor& y,
@@ -42,26 +44,33 @@ void subtract_grad(const Tensor& x,
Tensor* dy) {
if (dy) {
auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
if (phi::product(x.dims()) > phi::product(y.dims())) {
if (x.dims() != y.dims()) {
// May need to reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp, dy);
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
by_pass<T>(scale_out_grad, dy);
} else {
auto dy_reduce_res = sum<T>(
scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy);
}
} else {
by_pass<T>(scale_out_grad, dy);
}
}
if (dx) {
if (phi::product(y.dims()) > phi::product(x.dims())) {
if (y.dims() != x.dims()) {
// May need to reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp, dx);
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
by_pass<T>(out_grad, dx);
} else {
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx);
}
} else {
by_pass<T>(out_grad, dx);
}
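// add_grad: dx = dout and dy = dout, with the same broadcast reduction as subtract_grad.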
@@ -76,25 +85,34 @@ void add_grad(const Tensor& x,
Tensor* dx,
Tensor* dy) {
if (dy) {
if (phi::product(x.dims()) > phi::product(y.dims())) {
if (x.dims() != y.dims()) {
// May need to reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp, dy);
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
by_pass<T>(out_grad, dy);
} else {
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy);
}
} else {
by_pass<T>(out_grad, dy);
}
}
if (dx) {
if (phi::product(y.dims()) > phi::product(x.dims())) {
if (y.dims() != x.dims()) {
// May need to reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp, dx);
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
by_pass<T>(out_grad, dx);
} else {
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx);
}
} else {
by_pass<T>(out_grad, dx);
}
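// sum_grad: dout is broadcast back to x's shape with expand; axes dropped by
// the reduction are first restored with unsqueeze when needed.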
@@ -131,11 +149,12 @@ void sum_grad(const Tensor& x,
axis_ = axis.GetData();
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, x_dim);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
x_grad_tmp = expand<T>(out_grad, x_dim);
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}
set_output<T>(x_grad_tmp, x_grad);
set_output<T>(x_grad_tmp.impl(), x_grad);
}
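// divide_grad: for out = x / y, dx = dout / y and dy = -(x / (y * y)) * dout;
// broadcast axes are reduced as in the other binary grads.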
template <typename T>
@@ -152,15 +171,19 @@ void divide_grad(const Tensor& x,
auto tmp1 = divide<T>(x, tmp0);
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
auto dy_res = multiply<T>(tmp2, out_grad);
if (phi::product(x.dims()) > phi::product(y.dims())) {
if (x.dims() != y.dims()) {
// May need to reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp, dy);
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
set_output<T>(dy_res.impl(), dy);
} else {
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy);
}
} else {
set_output<T>(dy_res, dy);
set_output<T>(dy_res.impl(), dy);
}
} // indicate we will compute dy
if (dx) {
@@ -168,15 +191,20 @@ void divide_grad(const Tensor& x,
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (phi::product(y.dims()) > phi::product(x.dims())) {
if (y.dims() != x.dims()) {
// May need to reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp, dx);
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
set_output<T>(dx_res.impl(), dx);
} else {
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx);
}
} else {
set_output<T>(dx_res, dx);
set_output<T>(dx_res.impl(), dx);
}
} // indicate we will compute dx
}
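// sqrt_grad: dx = dout * 0.5 / out, since d/dx sqrt(x) = 1 / (2 * sqrt(x)) and out = sqrt(x).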
@@ -187,8 +215,89 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp);
set_output<T>(x_grad_tmp, x_grad);
set_output<T>(x_grad_tmp.impl(), x_grad);
}
}
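// multiply_grad: dx = dout * y and dy = dout * x; unreduced gradients are
// summed over the broadcast axes (and reshaped if the rank changed) when the
// operand shapes differ.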
template <typename T>
void multiply_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
Tensor* x_grad,
Tensor* y_grad) {
if (x_grad) {
auto x_grad_unreduce = multiply<T>(out_grad, y);
if (x.dims() != y.dims()) {
auto axes = get_reduce_dims(x.dims(), y.dims());
if (!axes.size()) {
set_output<T>(x_grad_unreduce.impl(), x_grad);
} else {
auto x_grad_reduced = sum<T>(x_grad_unreduce,
phi::vectorize(axes),
x_grad_unreduce.dtype(),
false);
if (x_grad_reduced.dims().size() != x.dims().size()) {
x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
}
set_output<T>(x_grad_reduced.impl(), x_grad);
}
} else {
set_output<T>(x_grad_unreduce.impl(), x_grad);
}
}
if (y_grad) {
auto y_grad_unreduce = multiply<T>(out_grad, x);
if (y.dims() != x.dims()) {
auto axes = get_reduce_dims(y.dims(), x.dims());
if (!axes.size()) {
set_output<T>(y_grad_unreduce.impl(), y_grad);
} else {
auto y_grad_reduced = sum<T>(y_grad_unreduce,
phi::vectorize(axes),
y_grad_unreduce.dtype(),
false);
if (y_grad_reduced.dims().size() != y.dims().size()) {
y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
}
set_output<T>(y_grad_reduced.impl(), y_grad);
}
} else {
set_output<T>(y_grad_unreduce.impl(), y_grad);
}
}
}
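// expand_grad: dout is summed over the axes introduced by expand back to x's
// shape; when nothing needs reducing, dout is passed through unchanged.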
template <typename T>
void expand_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& shape,
Tensor* x_grad) {
if (x_grad) {
auto out_dims = phi::make_ddim(shape.GetData());
if (out_dims != x.dims()) {
auto axes = get_reduce_dims(x.dims(), out_dims);
if (!axes.size()) {
by_pass<T>(out_grad, x_grad);
} else {
auto reduced = sum<T>(out_grad, phi::vectorize(axes), x.dtype(), false);
if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape());
}
set_output<T>(reduced.impl(), x_grad);
}
} else {
by_pass<T>(out_grad, x_grad);
}
}
}
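// exp_grad: dx = dout * out, since d/dx exp(x) = exp(x) = out.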
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
set_output<T>(multiply<T>(out_grad, out).impl(), x_grad);
}
}
} // namespace prim
} // namespace paddle
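
The broadcast handling above repeatedly reduces a gradient back to an operand's shape. The following is a minimal standalone sketch of that reduce-axes rule (not Paddle code; the helper name and signature are illustrative, and get_reduce_dims in the diff may differ in argument order and return type):

#include <cstdint>
#include <vector>

// Illustrative only: axes to sum over so that a gradient with the broadcast
// (output) shape is reduced back to the input's shape, numpy-style.
std::vector<int64_t> reduce_axes_for_broadcast(const std::vector<int64_t>& out_shape,
                                               const std::vector<int64_t>& in_shape) {
  std::vector<int64_t> axes;
  const int64_t offset =
      static_cast<int64_t>(out_shape.size()) - static_cast<int64_t>(in_shape.size());
  for (int64_t i = 0; i < static_cast<int64_t>(out_shape.size()); ++i) {
    if (i < offset) {
      axes.push_back(i);  // leading axis the input does not have
    } else if (in_shape[i - offset] == 1 && out_shape[i] != 1) {
      axes.push_back(i);  // axis broadcast from size 1
    }
  }
  return axes;  // e.g. out {2, 3, 4} vs in {3, 1} -> {0, 2}
}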