elementwise_add_op.h 10.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

F
fengjiayi 已提交
15 16
#pragma once

17 18
#include <algorithm>
#include <utility>
W
Wu Yi 已提交
19
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
20 21

// only can include the headers in paddle/pten/include dirs
22
#include "paddle/pten/kernels/math_kernel.h"
23

G
gongweibao 已提交
24 25 26
namespace paddle {
namespace operators {

27
template <typename DeviceContext, typename T>
28 29 30 31
void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx,
                                         const framework::Tensor *x,
                                         const framework::Tensor *y,
                                         framework::Tensor *z) {
32
  int axis = ctx.Attr<int>("axis");
33 34 35
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  if (x_dims.size() >= y_dims.size()) {
36 37 38 39 40 41
    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                          AddFunctor<T>(), z);
  } else {
    ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
        ctx, x, y, axis, InverseAddFunctor<T>(), z);
  }
42 43
}

44 45 46 47 48 49
// Primary template of a functor that, judging by its name, adds two tensors
// of the same dims: operator() reads x and y and writes z.
//
// Only declared here -- no definition is visible in this header, so the
// device-specific definitions/specializations presumably live in the
// corresponding .cc/.cu sources (TODO confirm). The defaulted `Enable`
// parameter is presumably there to allow enable_if-based partial
// specialization.
template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseAdd {
  void operator()(const framework::ExecutionContext &ctx,
                  const framework::Tensor *x, const framework::Tensor *y,
                  framework::Tensor *z);
};
50

Q
QI JUN 已提交
51
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
52
class ElementwiseAddKernel : public framework::OpKernel<T> {
G
gongweibao 已提交
53
 public:
C
chengduo 已提交
54 55 56 57
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<framework::LoDTensor>("X");
    auto *y = ctx.Input<framework::LoDTensor>("Y");
    auto *z = ctx.Output<framework::LoDTensor>("Out");
C
chengduoZH 已提交
58
    z->mutable_data<T>(ctx.GetPlace());
59 60 61 62 63 64

    auto &dev_ctx = ctx.device_context<DeviceContext>();
    int axis = ctx.Attr<int>("axis");
    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
    auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
65
    pten::AddKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
G
gongweibao 已提交
66 67 68 69
  }
};

template <typename T>
Y
Yu Yang 已提交
70 71
struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
G
gongweibao 已提交
72 73
};

74
template <typename DeviceContext, typename T>
75 76 77 78 79 80 81 82
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext &ctx,
                             const framework::Tensor *x,
                             const framework::Tensor *y,
                             const framework::Tensor *out,
                             const framework::Tensor *dout,
                             framework::Tensor *dx, framework::Tensor *dy) {
83 84
  int axis = ctx.Attr<int>("axis");

85 86 87 88
  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
                                               dx, dy, IdentityGrad<T>(),
                                               IdentityGrad<T>());
89 90
}

91
template <typename DeviceContext, typename T>
92 93 94
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
95 96 97 98 99
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
100 101 102 103 104 105 106 107 108 109 110 111
  auto blas = math::GetBlas<DeviceContext, T>(ctx);
  if (dx) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dx->mutable_data<T>(ctx.GetPlace()));
  }

  if (dy) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dy->mutable_data<T>(ctx.GetPlace()));
  }
}

112
template <typename DeviceContext, typename T>
113
typename std::enable_if<
114 115
    !std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
116 117 118 119 120 121
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
122 123
}

124
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// GPU (CUDA/HIP) overloads of the gradient helpers, selected when
// DeviceContext is platform::CUDADeviceContext. Only the declarations are in
// this header; the definitions are not visible here and presumably live in
// the GPU translation unit (TODO confirm).

// Fast path: the grad kernels call this only when dx and dy have equal dims.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x,
                     const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout,
                     framework::Tensor *dx,
                     framework::Tensor *dy);

// General path: used by the grad kernels whenever no special case applies.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext &ctx,
                             const framework::Tensor *x,
                             const framework::Tensor *y,
                             const framework::Tensor *out,
                             const framework::Tensor *dout,
                             framework::Tensor *dx,
                             framework::Tensor *dy);
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

Q
QI JUN 已提交
146
template <typename DeviceContext, typename T>
147
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
G
gongweibao 已提交
148
 public:
C
chengduo 已提交
149
  void Compute(const framework::ExecutionContext &ctx) const override {
150 151
    ElemwiseGradKernel<T>::Compute(ctx);

C
chengduoZH 已提交
152 153
    using Tensor = framework::Tensor;

154 155
    auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Input<Tensor>("Y");
C
chengduo 已提交
156 157 158
    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
159
    // skip out
C
chengduo 已提交
160
    auto *out = dout;
161

162 163 164 165 166 167 168 169 170 171 172 173 174 175
    // Special case when dy is not needed and dx doesn't reduce
    if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) {
      VLOG(4) << "Special case when dy is not needed and dx doesn't "
                 "reduce";
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dx);
    } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) {
      VLOG(4) << "Special case when dx is not needed and dy doesn't "
                 "reduce";
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dy);
    } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
176
      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
177
    } else {
178 179
      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
180
    }
G
gongweibao 已提交
181 182 183
  }
};

184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
// Double-grad kernel of elementwise_add: DDOut = DDX + DDY.
//
// DDX and DDY are first passed through GetDoubleGradSafeTensor, with DOut as
// the reference tensor for DDX and Y as the reference for DDY (presumably so
// an absent input is replaced by zeros of the reference's shape -- TODO
// confirm). The sum is computed by LaunchBroadcastElementwiseCpuKernel using
// the op's "axis" attribute. Nothing is done when DDOut is not requested.
template <typename DeviceContext, typename T>
class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;

    auto *y = ctx.Input<Tensor>("Y");
    auto *dout = ctx.Input<Tensor>("DOut");
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");
    auto *ddout = ctx.Output<Tensor>("DDOut");
    if (ddout == nullptr) {
      return;
    }

    Tensor safe_ddx;
    Tensor safe_ddy;
    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &safe_ddx);
    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &safe_ddy);

    ddout->mutable_data<T>(ctx.GetPlace());
    LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, &safe_ddx,
                                                          &safe_ddy, ddout);
  }
};

210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
// Triple-grad kernel of elementwise_add.
//
// Uses the same dispatch as the first-order grad kernel, with DDX/DDY in the
// role of X/Y and D_DDOut in the role of dOut:
//   1. only D_DDX requested and it has D_DDOut's dims -> TensorCopy;
//   2. only D_DDY requested and it has D_DDOut's dims -> TensorCopy;
//   3. both requested with identical dims             -> elementwise_add_grad;
//   4. anything else                 -> default_elementwise_add_grad.
// Out is not read; D_DDOut is passed in its place.
template <typename DeviceContext, typename T>
class ElementwiseAddTripleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");
    auto *d_ddout = ctx.Input<Tensor>("D_DDOut");
    auto *d_ddx = ctx.Output<Tensor>("D_DDX");
    auto *d_ddy = ctx.Output<Tensor>("D_DDY");
    // Out is skipped: D_DDOut stands in for it.
    auto *out = d_ddout;

    const bool want_d_ddx = d_ddx != nullptr;
    const bool want_d_ddy = d_ddy != nullptr;

    if (want_d_ddx && !want_d_ddy && d_ddx->dims() == d_ddout->dims()) {
      VLOG(4) << "Special case when d_ddy is not needed and d_ddx doesn't "
                 "reduce";
      framework::TensorCopy(
          *d_ddout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), d_ddx);
      return;
    }
    if (!want_d_ddx && want_d_ddy && d_ddy->dims() == d_ddout->dims()) {
      VLOG(4) << "Special case when d_ddx is not needed and d_ddy doesn't "
                 "reduce";
      framework::TensorCopy(
          *d_ddout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), d_ddy);
      return;
    }
    if (want_d_ddx && want_d_ddy && d_ddx->dims() == d_ddy->dims()) {
      elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out, d_ddout, d_ddx,
                                             d_ddy);
      return;
    }
    default_elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out, d_ddout,
                                                   d_ddx, d_ddy);
  }
};

G
gongweibao 已提交
249 250
}  // namespace operators
}  // namespace paddle