elementwise_sub_op.h 6.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

F
fengjiayi 已提交
15
#pragma once
16

W
Wu Yi 已提交
17
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
18
#include "paddle/fluid/platform/place.h"
G
gongweibao 已提交
19

20
#include "paddle/pten/kernels/math_kernel.h"
G
gongweibao 已提交
21 22 23
namespace paddle {
namespace operators {

24 25 26 27 28
// Generic z = x - y using ElementwiseComputeEx, broadcasting along the "axis"
// attribute. When y has strictly more dimensions than x, InverseSubFunctor is
// used instead of SubFunctor -- presumably because ElementwiseComputeEx
// expects the higher-rank operand first and the inverse functor restores the
// x - y order; TODO confirm against elementwise_op.h.
template <typename DeviceContext, typename T>
void default_elementwise_sub(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y, framework::Tensor* z) {
  const int axis = ctx.Attr<int>("axis");
  const bool y_has_more_dims = x->dims().size() < y->dims().size();
  if (y_has_more_dims) {
    ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
        ctx, x, y, axis, InverseSubFunctor<T>(), z);
    return;
  }
  ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                        SubFunctor<T>(), z);
}

Q
QI JUN 已提交
40
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
41
class ElementwiseSubKernel : public framework::OpKernel<T> {
G
gongweibao 已提交
42 43
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
C
chengduo 已提交
44 45 46
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");
C
chengduoZH 已提交
47
    z->mutable_data<T>(ctx.GetPlace());
48

49 50 51 52 53
    auto& dev_ctx = ctx.device_context<DeviceContext>();
    int axis = ctx.Attr<int>("axis");
    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
    auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
54 55
    pten::SubtractKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
                            pt_z.get());
G
gongweibao 已提交
56 57 58 59
  }
};

template <typename T>
C
chengduoZH 已提交
60 61
struct SubGradDX {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
G
gongweibao 已提交
62 63 64
};

template <typename T>
C
chengduoZH 已提交
65 66
struct SubGradDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
G
gongweibao 已提交
67 68
};

69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
// CPU backward for Out = X - Y: computes dx (= dout) and dy (= -dout) with
// ElemwiseExplicitGradCompute, using the "axis" attribute for broadcasting.
// Enabled only when DeviceContext is platform::CPUDeviceContext.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  using DXFunctor = SubGradDX<T>;
  using DYFunctor = SubGradDY<T>;

  const int axis = ctx.Attr<int>("axis");
  ElemwiseExplicitGradCompute<DeviceContext, T, DXFunctor, DYFunctor>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, DXFunctor(), DYFunctor());
}

84 85 86 87 88 89 90 91
// CPU variant of the equal-shape gradient path (ElementwiseSubGradKernel
// picks it when dx and dy are both requested with identical dims). The CPU
// build has no dedicated routine for that case, so it forwards to
// default_elementwise_sub_grad unchanged.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x,
                     const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout,
                     framework::Tensor* dx,
                     framework::Tensor* dy) {
  default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout,
                                                 dx, dy);
}

95
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
96
// CUDA overloads: only declared in this header; their definitions are not here.
97 98 99 100 101 102 103 104 105 106
// CUDA overload of the general backward routine for Out = X - Y.
// Declaration only (no body in this header); selected through enable_if
// when DeviceContext is platform::CUDADeviceContext.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy);

107 108 109 110 111 112 113 114 115 116
// CUDA overload of the equal-shape backward path, used by
// ElementwiseSubGradKernel when dx and dy are both requested with identical
// dims. Declaration only (no body in this header); selected through
// enable_if when DeviceContext is platform::CUDADeviceContext.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy);
#endif

Q
QI JUN 已提交
117
template <typename DeviceContext, typename T>
118
class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
G
gongweibao 已提交
119 120
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
121
    ElemwiseGradKernel<T>::Compute(ctx);
C
chengduoZH 已提交
122 123
    using Tensor = framework::Tensor;

124 125
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
C
chengduoZH 已提交
126 127 128
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
129
    // skip out
130
    auto* out = dout;
131 132 133
    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
      elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
134 135
      default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
136
    }
G
gongweibao 已提交
137 138
  }
};
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166

template <typename DeviceContext, typename T>
class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
 public:
  // Second-order gradient of Out = X - Y: DDOut = DDX - DDY, broadcasting
  // along the "axis" attribute. Does nothing when DDOut is not requested.
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* y_in = ctx.Input<Tensor>("Y");
    auto* dout_in = ctx.Input<Tensor>("DOut");
    auto* ddx_in = ctx.Input<Tensor>("DDX");
    auto* ddy_in = ctx.Input<Tensor>("DDY");
    auto* ddout_out = ctx.Output<Tensor>("DDOut");

    if (ddout_out == nullptr) {
      return;
    }

    // GetDoubleGradSafeTensor turns ddx/ddy into usable operands, with dout
    // and y as the respective reference tensors; presumably it substitutes a
    // zero tensor shaped like the reference when the input is absent --
    // TODO confirm against elementwise_op.h.
    Tensor safe_ddx;
    Tensor safe_ddy;
    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout_in, ddx_in, &safe_ddx);
    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y_in, ddy_in, &safe_ddy);

    ddout_out->mutable_data<T>(ctx.GetPlace());
    const int bcast_axis = ctx.Attr<int>("axis");
    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
        ctx, &safe_ddx, &safe_ddy, bcast_axis, SubFunctor<T>(), ddout_out);
  }
};

G
gongweibao 已提交
167 168
}  // namespace operators
}  // namespace paddle