/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {

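// Broadcast-aware addition: ElementwiseComputeEx broadcasts y against x
// along the "axis" attribute and applies AddFunctor element by element.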
template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext &ctx,
                             const framework::Tensor *x,
                             const framework::Tensor *y, framework::Tensor *z) {
  int axis = ctx.Attr<int>("axis");
  ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                        AddFunctor<T>(), z);
}

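// Fast path for operands of identical shape. Only the primary template is
// declared here; the per-device specializations supply the implementation.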
template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseAdd {
  void operator()(const framework::ExecutionContext &ctx,
                  const framework::Tensor *x, const framework::Tensor *y,
                  framework::Tensor *z);
};

template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<framework::LoDTensor>("X");
    auto *y = ctx.Input<framework::LoDTensor>("Y");
    auto *z = ctx.Output<framework::LoDTensor>("Out");
    z->mutable_data<T>(ctx.GetPlace());
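    // Identical shapes take the specialized fast path; any mismatch falls
    // back to the broadcasting implementation.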
    auto dims_equal = x->dims() == y->dims();
    if (dims_equal) {
      SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
      same_dims_add(ctx, x, y, z);
    } else {
      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
    }
  }
};

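// d(x + y)/dx == d(x + y)/dy == 1, so each input's gradient is just dout
// passed through unchanged.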
template <typename T>
struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};

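// Generic backward path; handles broadcasting by reducing dout back to the
// shapes of x and y where necessary.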
template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
                                  const framework::Tensor *x,
                                  const framework::Tensor *y,
                                  const framework::Tensor *out,
                                  const framework::Tensor *dout,
                                  framework::Tensor *dx,
                                  framework::Tensor *dy) {
  int axis = ctx.Attr<int>("axis");

  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
                                               dx, dy, IdentityGrad<T>(),
                                               IdentityGrad<T>());
}

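// CPU backward path for floating-point types: the gradient is the identity,
// so dx and dy are plain copies of dout done via BLAS VCOPY.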
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
  auto blas = math::GetBlas<DeviceContext, T>(ctx);
  if (dx) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dx->mutable_data<T>(ctx.GetPlace()));
  }

  if (dy) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dy->mutable_data<T>(ctx.GetPlace()));
  }
}

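// Non-floating-point types cannot use VCOPY, so they go through the generic
// path instead.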
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#ifdef PADDLE_WITH_CUDA
// CUDA specialization: only declared here; the definition lives in the
// corresponding .cu file.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy);
#endif

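// Backward kernel computing dX and dY from dOut.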
template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);

    using Tensor = framework::Tensor;

    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    // Out, X and Y are not read by the identity gradient, so alias them to
    // dout rather than fetching unused inputs.
    auto *out = dout;
    auto *x = dout, *y = dout;

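    // If both gradients are requested and share a shape, no broadcast
    // reduction is needed and the specialized path applies.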
    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
    }
  }
};

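// Double-grad kernel: out = x + y is linear, so ddOut = ddX + ddY.
// GetDoubleGradSafeTensor substitutes a zero-filled tensor when the
// corresponding input grad (ddx/ddy) is absent.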
template <typename DeviceContext, typename T>
class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;

    auto *y = ctx.Input<Tensor>("Y");
    auto *dout = ctx.Input<Tensor>("DOut");
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");

    auto *ddout = ctx.Output<Tensor>("DDOut");

    // ddOut = ddx + ddy
    if (ddout) {
      Tensor ddx_safe, ddy_safe;
      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);

      ddout->mutable_data<T>(ctx.GetPlace());
      default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
                                                ddout);
    }
  }
};

}  // namespace operators
}  // namespace paddle