elementwise_add_op.h 10.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

F
fengjiayi 已提交
15 16
#pragma once

17 18
#include <algorithm>
#include <utility>
W
Wu Yi 已提交
19
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
20
#include "paddle/fluid/operators/math/blas.h"
21
#include "paddle/fluid/operators/math/math_function.h"
22

23 24 25 26 27 28 29
#include "paddle/fluid/framework/pten_utils.h"

// Only headers from the paddle/pten/include directories may be included here.
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"

G
gongweibao 已提交
30 31 32
namespace paddle {
namespace operators {

33
template <typename DeviceContext, typename T>
34 35 36 37
void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx,
                                         const framework::Tensor *x,
                                         const framework::Tensor *y,
                                         framework::Tensor *z) {
38
  int axis = ctx.Attr<int>("axis");
39 40 41
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  if (x_dims.size() >= y_dims.size()) {
42 43 44 45 46 47
    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                          AddFunctor<T>(), z);
  } else {
    ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
        ctx, x, y, axis, InverseAddFunctor<T>(), z);
  }
48 49
}

50 51 52 53 54 55
template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseAdd {
  void operator()(const framework::ExecutionContext &ctx,
                  const framework::Tensor *x, const framework::Tensor *y,
                  framework::Tensor *z);
};
56

Q
QI JUN 已提交
57
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
58
class ElementwiseAddKernel : public framework::OpKernel<T> {
G
gongweibao 已提交
59
 public:
C
chengduo 已提交
60 61 62 63
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<framework::LoDTensor>("X");
    auto *y = ctx.Input<framework::LoDTensor>("Y");
    auto *z = ctx.Output<framework::LoDTensor>("Out");
C
chengduoZH 已提交
64
    z->mutable_data<T>(ctx.GetPlace());
65 66 67 68 69 70

    auto &dev_ctx = ctx.device_context<DeviceContext>();
    int axis = ctx.Attr<int>("axis");
    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
    auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
71
    pten::Add<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
G
gongweibao 已提交
72 73 74 75
  }
};

template <typename T>
Y
Yu Yang 已提交
76 77
struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
G
gongweibao 已提交
78 79
};

80
template <typename DeviceContext, typename T>
81 82 83 84 85 86 87 88
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext &ctx,
                             const framework::Tensor *x,
                             const framework::Tensor *y,
                             const framework::Tensor *out,
                             const framework::Tensor *dout,
                             framework::Tensor *dx, framework::Tensor *dy) {
89 90
  int axis = ctx.Attr<int>("axis");

91 92 93 94
  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
                                               dx, dy, IdentityGrad<T>(),
                                               IdentityGrad<T>());
95 96
}

97
template <typename DeviceContext, typename T>
98 99 100
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
101 102 103 104 105
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
106 107 108 109 110 111 112 113 114 115 116 117
  auto blas = math::GetBlas<DeviceContext, T>(ctx);
  if (dx) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dx->mutable_data<T>(ctx.GetPlace()));
  }

  if (dy) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dy->mutable_data<T>(ctx.GetPlace()));
  }
}

118
template <typename DeviceContext, typename T>
119
typename std::enable_if<
120 121
    !std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
122 123 124 125 126 127
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
128 129
}

130
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
131 132 133 134
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
135 136 137 138 139
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy);
140 141 142 143 144 145 146 147 148 149

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext &ctx,
                             const framework::Tensor *x,
                             const framework::Tensor *y,
                             const framework::Tensor *out,
                             const framework::Tensor *dout,
                             framework::Tensor *dx, framework::Tensor *dy);
150 151
#endif

Q
QI JUN 已提交
152
template <typename DeviceContext, typename T>
153
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
G
gongweibao 已提交
154
 public:
C
chengduo 已提交
155
  void Compute(const framework::ExecutionContext &ctx) const override {
156 157
    ElemwiseGradKernel<T>::Compute(ctx);

C
chengduoZH 已提交
158 159
    using Tensor = framework::Tensor;

160 161
    auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Input<Tensor>("Y");
C
chengduo 已提交
162 163 164
    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
165
    // skip out
C
chengduo 已提交
166
    auto *out = dout;
167

168 169 170 171 172 173 174 175 176 177 178 179 180 181
    // Special case when dy is not needed and dx doesn't reduce
    if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) {
      VLOG(4) << "Special case when dy is not needed and dx doesn't "
                 "reduce";
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dx);
    } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) {
      VLOG(4) << "Special case when dx is not needed and dy doesn't "
                 "reduce";
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dy);
    } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
182
      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
183
    } else {
184 185
      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
186
    }
G
gongweibao 已提交
187 188 189
  }
};

190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
template <typename DeviceContext, typename T>
class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;

    auto *y = ctx.Input<Tensor>("Y");
    auto *dout = ctx.Input<Tensor>("DOut");
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");

    auto *ddout = ctx.Output<Tensor>("DDOut");

    // ddOut = ddx + ddy
    if (ddout) {
      Tensor ddx_safe, ddy_safe;
      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);

      ddout->mutable_data<T>(ctx.GetPlace());
210 211
      LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, &ddx_safe,
                                                            &ddy_safe, ddout);
212 213 214 215
    }
  }
};

216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
template <typename DeviceContext, typename T>
class ElementwiseAddTripleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");
    auto *d_ddout = ctx.Input<Tensor>("D_DDOut");
    auto *d_ddx = ctx.Output<Tensor>("D_DDX");
    auto *d_ddy = ctx.Output<Tensor>("D_DDY");
    // skip out
    auto *out = d_ddout;

    // Special case when d_ddy is not needed and d_ddx doesn't reduce
    if (d_ddx != nullptr && d_ddy == nullptr &&
        d_ddx->dims() == d_ddout->dims()) {
      VLOG(4) << "Special case when d_ddy is not needed and d_ddx doesn't "
                 "reduce";
      framework::TensorCopy(
          *d_ddout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), d_ddx);
    } else if (d_ddx == nullptr && d_ddy != nullptr &&
               d_ddy->dims() == d_ddout->dims()) {
      VLOG(4) << "Special case when d_ddx is not needed and d_ddy doesn't "
                 "reduce";
      framework::TensorCopy(
          *d_ddout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), d_ddy);
    } else if (d_ddx != nullptr && d_ddy != nullptr &&
               (d_ddx->dims() == d_ddy->dims())) {
      elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out, d_ddout, d_ddx,
                                             d_ddy);
    } else {
      default_elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out,
                                                     d_ddout, d_ddx, d_ddy);
    }
  }
};

G
gongweibao 已提交
255 256
}  // namespace operators
}  // namespace paddle