/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"

// Short namespace aliases. `ops` is used by the REGISTER_OP_CUDA_KERNEL calls
// later in this file; `plat` is not referenced in the visible part of the
// file and is kept unchanged.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// CUDA overload of the elementwise_div gradient helper.
//
// Participates in overload resolution only when DeviceContext is
// platform::CUDADeviceContext (std::enable_if with the default `void` type),
// so the function returns nothing.
//
// Arguments:
//   ctx   execution context; provides the "axis" attribute, the device
//         context and the place handed to mutable_data.
//   x     only its dims are read here, to re-allocate dx when needed.
//   y     forwarded to every gradient functor.
//   out   forwarded to the functors that produce dy.
//   dout  forwarded to every gradient functor; dereferenced whenever dx is
//         requested, so it must be non-null in that case.
//   dx    destination tensor for the x-gradient, or nullptr to skip it.
//   dy    destination tensor for the y-gradient, or nullptr to skip it.
//
// When both dx and dy are nullptr no gradient routine is invoked.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  const auto& dev_ctx = ctx.template device_context<DeviceContext>();
  const auto place = ctx.GetPlace();

  const bool need_dx = (dx != nullptr);
  const bool need_dy = (dy != nullptr);

  if (need_dx) {
    // Prepare dx before any gradient routine runs. If dx turns out to share
    // its buffer with dout, release it and allocate again using x's dims.
    // NOTE(review): presumably this avoids writing dx over dout while dout is
    // still being read -- confirm against GetGradXAndYOut / GetGradXOrYOut.
    dx->mutable_data<T>(place);
    if (dx->IsSharedBufferWith(*dout)) {
      dx->clear();
      dx->mutable_data<T>(x->dims(), place);
    }
  }

  if (need_dx && need_dy) {
    // Both gradients requested: a single ternary pass over (dout, out, y).
    std::vector<const framework::Tensor*> grad_inputs = {dout, out, y};
    GetGradXAndYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, grad_inputs, dout, dx, dy,
        DivGradXYFunctor<T, T>());
  } else if (need_dx) {
    // Only dx requested: binary pass over (dout, y).
    std::vector<const framework::Tensor*> grad_inputs = {dout, y};
    GetGradXOrYOut<ElementwiseType::kBinary, T>(
        dev_ctx, place, axis, grad_inputs, dout, dx, DivGradXFunctor<T>());
  } else if (need_dy) {
    // Only dy requested: ternary pass over (dout, out, y). dy is not
    // allocated here; it is passed to GetGradXOrYOut as-is.
    std::vector<const framework::Tensor*> grad_inputs = {dout, out, y};
    GetGradXOrYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, grad_inputs, dout, dy, DivGradYFunctor<T>());
  }
}

}  // namespace operators
}  // namespace paddle
// Registers the CUDA kernels of the `elementwise_div` op: one
// ElementwiseDivKernel instantiation per element type (float, float16,
// double, int, int64_t, complex<float>, complex<double>).
REGISTER_OP_CUDA_KERNEL(
    elementwise_div,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::complex<float>>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::complex<double>>);
// Registers the CUDA kernels of the `elementwise_div_grad` op: one
// ElementwiseDivGradKernel instantiation per element type (float, float16,
// double, int, int64_t, complex<float>, complex<double>).
REGISTER_OP_CUDA_KERNEL(
    elementwise_div_grad,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::float16>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::complex<float>>,
    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::complex<double>>);
// Registers the CUDA kernels of the `elementwise_div_grad_grad` op: one
// ElementwiseDivDoubleGradKernel instantiation per element type (float,
// float16, double, int, int64_t, complex<float>, complex<double>).
REGISTER_OP_CUDA_KERNEL(
    elementwise_div_grad_grad,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::float16>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        double>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<double>>);