未验证 提交 c642aa17 编写于 作者: X xiaoguoguo626807 提交者: GitHub

add_triple_grad rules (#54164)

上级 94a56cc1
...@@ -67,6 +67,7 @@ black_ops_list = [ ...@@ -67,6 +67,7 @@ black_ops_list = [
prim_white_list = [ prim_white_list = [
"matmul_double_grad", "matmul_double_grad",
"subtract_double_grad", "subtract_double_grad",
"add_triple_grad",
"silu_double_grad", "silu_double_grad",
] ]
......
...@@ -154,6 +154,44 @@ class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -154,6 +154,44 @@ class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
class ElementwiseAddCompositeTripleGradOpMaker
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
// get input
paddle::Tensor ddx = this->GetSingleForwardInput("DDX");
paddle::Tensor ddy = this->GetSingleForwardInput("DDY");
paddle::Tensor d_ddout = this->GetSingleOutputGrad("DDOut");
// get output
paddle::Tensor grad_grad_x_t =
this->GetSingleInputGrad(framework::GradVarName("DDX"));
paddle::Tensor grad_grad_y_t =
this->GetSingleInputGrad(framework::GradVarName("DDY"));
// get attr
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument("We only support axis = -1 in composite "
"add_triple_grad but we got: ",
axis));
paddle::Tensor* grad_grad_x = this->GetOutputPtr(&grad_grad_x_t);
std::string grad_grad_x_name = this->GetOutputName(grad_grad_x_t);
paddle::Tensor* grad_grad_y = this->GetOutputPtr(&grad_grad_y_t);
std::string grad_grad_y_name = this->GetOutputName(grad_grad_y_t);
VLOG(6) << "Runing add_triple_grad composite func";
prim::add_triple_grad<prim::DescTensor>(
ddx, ddy, d_ddout, axis, grad_grad_x, grad_grad_y);
this->RecoverOutputName(grad_grad_x_t, grad_grad_x_name);
this->RecoverOutputName(grad_grad_y_t, grad_grad_y_name);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -234,30 +234,6 @@ void subtract_grad(const Tensor& x, ...@@ -234,30 +234,6 @@ void subtract_grad(const Tensor& x,
} }
} }
template <typename T>
void subtract_double_grad(const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
int axis,
Tensor* grad_out_grad) {
if (grad_out_grad) {
// ddout = ddx - ddy
if (!grad_x_grad && !grad_y_grad) {
grad_out_grad = nullptr;
} else {
Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
if (grad_x_grad) {
ddout = ddout + grad_x_grad.get();
}
if (grad_y_grad) {
ddout = ddout - grad_y_grad.get();
}
set_output<T>(ddout, grad_out_grad);
}
}
}
template <typename T> template <typename T>
void add_grad(const Tensor& x, void add_grad(const Tensor& x,
const Tensor& y, const Tensor& y,
...@@ -300,30 +276,6 @@ void add_grad(const Tensor& x, ...@@ -300,30 +276,6 @@ void add_grad(const Tensor& x,
} }
} }
template <typename T>
void add_double_grad(const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
int axis,
Tensor* grad_out_grad) {
if (grad_out_grad) {
// ddout = ddx + ddy
if (!grad_x_grad && !grad_y_grad) {
grad_out_grad = nullptr;
} else {
Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
if (grad_x_grad) {
ddout = ddout + grad_x_grad.get();
}
if (grad_y_grad) {
ddout = ddout + grad_y_grad.get();
}
set_output<T>(ddout, grad_out_grad);
}
}
}
template <typename T> template <typename T>
void sum_grad(const Tensor& x, void sum_grad(const Tensor& x,
const Tensor& out_grad, const Tensor& out_grad,
...@@ -555,75 +507,6 @@ void multiply_grad(const Tensor& x, ...@@ -555,75 +507,6 @@ void multiply_grad(const Tensor& x,
} }
} }
template <typename T>
void multiply_double_grad(const Tensor& x,
const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
int axis,
Tensor* x_grad,
Tensor* y_grad,
Tensor* grad_out_grad) {
if (x_grad) {
if (grad_y_grad) {
auto dx = grad_y_grad.get() * grad_out;
if (dx.dims() != x.dims()) {
auto axes = get_reduce_dims_from_out(dx.dims(), x.dims());
if (!axes.size()) {
set_output<T>(dx, x_grad);
} else {
auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false);
if (dx_reduce.dims().size() != x.dims().size()) {
dx_reduce = reshape<T>(dx_reduce, x.shape());
}
set_output<T>(dx_reduce, x_grad);
}
} else {
set_output<T>(dx, x_grad);
}
} else {
x_grad = nullptr;
}
}
if (y_grad) {
if (grad_x_grad) {
auto dy = grad_x_grad.get() * grad_out;
if (dy.dims() != y.dims()) {
auto axes = get_reduce_dims_from_out(dy.dims(), y.dims());
if (!axes.size()) {
set_output<T>(dy, y_grad);
} else {
auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), false);
if (dy_reduce.dims().size() != y.dims().size()) {
dy_reduce = reshape<T>(dy_reduce, y.shape());
}
set_output<T>(dy_reduce, y_grad);
}
} else {
set_output<T>(dy, y_grad);
}
} else {
y_grad = nullptr;
}
}
if (grad_out_grad) {
if (grad_x_grad && grad_y_grad) {
auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
set_output<T>(ddout, grad_out_grad);
} else if (grad_x_grad) {
auto ddout = grad_x_grad.get() * y;
set_output<T>(ddout, grad_out_grad);
} else if (grad_y_grad) {
auto ddout = grad_y_grad.get() * x;
set_output<T>(ddout, grad_out_grad);
} else {
grad_out_grad = nullptr;
}
}
}
template <typename T> template <typename T>
void expand_grad(const Tensor& x, void expand_grad(const Tensor& x,
const Tensor& out_grad, const Tensor& out_grad,
......
...@@ -383,5 +383,175 @@ void silu_double_grad(const Tensor& x, ...@@ -383,5 +383,175 @@ void silu_double_grad(const Tensor& x,
} }
} }
template <typename T>
void multiply_double_grad(const Tensor& x,
const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
int axis,
Tensor* x_grad,
Tensor* y_grad,
Tensor* grad_out_grad) {
if (x_grad) {
if (grad_y_grad) {
auto dx = grad_y_grad.get() * grad_out;
if (dx.dims() != x.dims()) {
auto axes = get_reduce_dims_from_out(dx.dims(), x.dims());
if (!axes.size()) {
set_output<T>(dx, x_grad);
} else {
auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false);
if (dx_reduce.dims().size() != x.dims().size()) {
dx_reduce = reshape<T>(dx_reduce, x.shape());
}
set_output<T>(dx_reduce, x_grad);
}
} else {
set_output<T>(dx, x_grad);
}
} else {
x_grad = nullptr;
}
}
if (y_grad) {
if (grad_x_grad) {
auto dy = grad_x_grad.get() * grad_out;
if (dy.dims() != y.dims()) {
auto axes = get_reduce_dims_from_out(dy.dims(), y.dims());
if (!axes.size()) {
set_output<T>(dy, y_grad);
} else {
auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), false);
if (dy_reduce.dims().size() != y.dims().size()) {
dy_reduce = reshape<T>(dy_reduce, y.shape());
}
set_output<T>(dy_reduce, y_grad);
}
} else {
set_output<T>(dy, y_grad);
}
} else {
y_grad = nullptr;
}
}
if (grad_out_grad) {
if (grad_x_grad && grad_y_grad) {
auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
set_output<T>(ddout, grad_out_grad);
} else if (grad_x_grad) {
auto ddout = grad_x_grad.get() * y;
set_output<T>(ddout, grad_out_grad);
} else if (grad_y_grad) {
auto ddout = grad_y_grad.get() * x;
set_output<T>(ddout, grad_out_grad);
} else {
grad_out_grad = nullptr;
}
}
}
template <typename T>
void add_double_grad(const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
int axis,
Tensor* grad_out_grad) {
if (grad_out_grad) {
// ddout = ddx + ddy
if (!grad_x_grad && !grad_y_grad) {
grad_out_grad = nullptr;
} else {
Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
if (grad_x_grad) {
ddout = ddout + grad_x_grad.get();
}
if (grad_y_grad) {
ddout = ddout + grad_y_grad.get();
}
set_output<T>(ddout, grad_out_grad);
}
}
}
template <typename T>
void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
const paddle::optional<Tensor>& grad_grad_y,
const Tensor& grad_grad_out_grad,
int axis,
Tensor* grad_grad_x_grad,
Tensor* grad_grad_y_grad) {
if (grad_grad_y_grad) {
if (grad_grad_y) {
if (grad_grad_y.get().dims() != grad_grad_out_grad.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(grad_grad_y.get().dims(),
grad_grad_out_grad.dims());
if (!reduce_dim.size()) {
by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
} else {
auto dddy_reduce_res = grad_grad_out_grad.sum(
phi::vectorize(reduce_dim), grad_grad_y.get().dtype(), false);
auto dddy_tmp = reshape<T>(dddy_reduce_res,
phi::vectorize(grad_grad_y.get().dims()));
set_output<T>(dddy_tmp, grad_grad_y_grad);
}
} else {
by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
}
} else {
grad_grad_y_grad = nullptr;
}
}
if (grad_grad_x_grad) {
if (grad_grad_x) {
if (grad_grad_x.get().dims() != grad_grad_out_grad.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(grad_grad_x.get().dims(),
grad_grad_out_grad.dims());
if (!reduce_dim.size()) {
by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
} else {
auto dddx_reduce_res = grad_grad_out_grad.sum(
phi::vectorize(reduce_dim), grad_grad_x.get().dtype(), false);
auto dddx_tmp = reshape<T>(dddx_reduce_res,
phi::vectorize(grad_grad_x.get().dims()));
set_output<T>(dddx_tmp, grad_grad_x_grad);
}
} else {
by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
}
} else {
grad_grad_x_grad = nullptr;
}
}
}
template <typename T>
void subtract_double_grad(const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
int axis,
Tensor* grad_out_grad) {
if (grad_out_grad) {
// ddout = ddx - ddy
if (!grad_x_grad && !grad_y_grad) {
grad_out_grad = nullptr;
} else {
Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
if (grad_x_grad) {
ddout = ddout + grad_x_grad.get();
}
if (grad_y_grad) {
ddout = ddout - grad_y_grad.get();
}
set_output<T>(ddout, grad_out_grad);
}
}
}
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
kernel : kernel :
func : add_triple_grad func : add_triple_grad
inplace : (grad_grad_out_grad -> grad_grad_x_grad) inplace : (grad_grad_out_grad -> grad_grad_x_grad)
composite : add_triple_grad (grad_grad_x, grad_grad_y, grad_grad_out_grad, axis, grad_grad_x_grad, grad_grad_y_grad )
- backward_op : amax_grad - backward_op : amax_grad
forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out) forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册