/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Fills out_data[0 .. N) with the single value in_data[0] / N.
//
// Used by MeanCUDAGradKernel below: Out = mean(X) is one value, so every
// element of dX equals dOut / numel(X).
//
// Launch: 1-D grid and 1-D blocks of any size. The grid-stride loop below
// makes the result independent of the launch configuration, as long as at
// least one block is launched. Every launched thread reads in_data[0], so
// in_data must hold at least one element.
//
// The division is done in MT, the compute type from details::MPTypeTrait
// (presumably float when T is float16 -- see fp16_type_traits.h). The result
// is converted back to T once.
template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int64_t N) {
  using MT = typename details::MPTypeTrait<T>::Type;
  // The value is the same for every output element, so compute it once per
  // thread instead of once per loop iteration.
  const T value =
      static_cast<T>(static_cast<MT>(in_data[0]) / static_cast<MT>(N));
  // 64-bit indexing: N comes from Tensor::numel() and may exceed INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx =
           static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
       idx < N; idx += stride) {
    out_data[idx] = value;
  }
}

// Forward CUDA kernel of the `mean` operator: Out = sum(X) / numel(X),
// reduced over every axis of X.
template <typename DeviceContext, typename T>
class MeanCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");

    auto place = context.GetPlace();
    const T* in_data = input->data<T>();
    T* out_data = output->mutable_data<T>(place);

    auto numel = input->numel();
    auto rank = input->dims().size();
    auto stream = context.cuda_device_context().stream();

    // A rank-0 (scalar) input is its own mean: copy it device-to-device on
    // the op's stream instead of running a reduction.
    if (rank == 0) {
      auto gpu_place = BOOST_GET(platform::CUDAPlace, place);
      memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T),
                   stream);
      return;
    }

    // Sum over all axes with AddFunctor. Div(numel) is passed as the
    // transform applied to the reduction; presumably it scales the sum by
    // 1 / numel in MT precision.
    using MT = typename details::MPTypeTrait<T>::Type;
    using Div = kernel_primitives::DivideFunctor<T, MT>;
    std::vector<int> reduce_dims(rank);
    for (decltype(rank) i = 0; i < rank; ++i) {
      reduce_dims[i] = i;
    }
    TensorReduceFunctorImpl<T, T, kernel_primitives::AddFunctor, Div>(
        *input, output, Div(numel), reduce_dims, stream);
  }
};

// Backward CUDA kernel of the `mean` operator: dX[i] = dOut / numel(X).
//
// Raises InvalidArgument when Out@GRAD does not have exactly one element.
template <typename DeviceContext, typename T>
class MeanCUDAGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE_EQ(OG->numel(), 1,
                      platform::errors::InvalidArgument(
                          "Mean Gradient Input Tensor len should be 1. But "
                          "received Out@Grad's elements num is %d.",
                          OG->numel()));
    auto IG = context.Output<Tensor>(framework::GradVarName("X"));
    T* out_data = IG->mutable_data<T>(context.GetPlace());

    const int64_t size_prob = IG->numel();
    // An empty X has nothing to fill. Launching with a zero-sized grid is an
    // invalid CUDA launch configuration and would leave a pending error.
    if (size_prob == 0) {
      return;
    }

    const T* in_data = OG->data<T>();
    constexpr int kThreads = 512;
    const int grid = static_cast<int>((size_prob + kThreads - 1) / kThreads);
    auto stream = context.cuda_device_context().stream();
    MeanRunKernel<T><<<grid, kThreads, 0, stream>>>(in_data, out_data,
                                                    size_prob);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward `mean` kernel for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);

// Register the backward `mean_grad` kernel for the same element types.
REGISTER_OP_CUDA_KERNEL(
    mean_grad,
    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
                            plat::float16>);