// math_function.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yu Yang 已提交
14
#include <vector>
Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/data_type.h"
Y
Yu Yang 已提交
16
#include "paddle/fluid/operators/math/blas.h"
Y
Yi Wang 已提交
17 18
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
19
#include "paddle/fluid/platform/float16.h"
Q
qijun 已提交
20

Q
qijun 已提交
21 22 23 24
namespace paddle {
namespace operators {
namespace math {

25 26
using float16 = paddle::platform::float16;

// Explicit GPU instantiations of the SetConstant functor for each element
// type listed below. The helper macro is local to this list.
#define PADDLE_INSTANTIATE_GPU_SET_CONSTANT(T) \
  template struct SetConstant<platform::CUDADeviceContext, T>
PADDLE_INSTANTIATE_GPU_SET_CONSTANT(platform::float16);
PADDLE_INSTANTIATE_GPU_SET_CONSTANT(float);
PADDLE_INSTANTIATE_GPU_SET_CONSTANT(double);
PADDLE_INSTANTIATE_GPU_SET_CONSTANT(int);
PADDLE_INSTANTIATE_GPU_SET_CONSTANT(int64_t);
PADDLE_INSTANTIATE_GPU_SET_CONSTANT(bool);
#undef PADDLE_INSTANTIATE_GPU_SET_CONSTANT
33

Q
qingqing01 已提交
34 35 36 37 38
// Explicitly instantiates the GPU Transpose functor at tensor rank RANK for
// each supported element type (int8_t, float16, float, double).
#define DEFINE_GPU_TRANS(RANK)                                           \
  template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>;  \
  template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
  template struct Transpose<platform::CUDADeviceContext, float, RANK>;   \
  template struct Transpose<platform::CUDADeviceContext, double, RANK>

// Ranks 1 through 6.
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);
Q
qijun 已提交
46

47 48
// Type-dispatch visitor: apply<T>() fills `tensor_` with `value_` converted
// to T, using the GPU SetConstant<CUDADeviceContext, T> functor. In this file
// it is constructed by set_constant_with_place<platform::CUDAPlace>, so
// `context_` is expected to refer to a CUDADeviceContext.
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const platform::DeviceContext& context,
                       framework::Tensor* tensor, float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename T>
  void apply() const {
    // static_cast is the correct base-to-derived downcast. reinterpret_cast
    // merely reinterprets the address and ignores any base-subobject offset,
    // which is undefined behavior for this purpose.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context_);
    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(cuda_context, tensor_, static_cast<T>(value_));
  }

  const platform::DeviceContext& context_;  // expected: a CUDADeviceContext
  framework::Tensor* tensor_;               // tensor to fill (not owned)
  float value_;                             // fill value, cast to T in apply()
};

// CUDAPlace specialization: fills `tensor` with `value`. It hands a
// TensorSetConstantGPU visitor to framework::VisitDataType together with the
// tensor's type(), so the matching SetConstant instantiation is selected.
template <>
void set_constant_with_place<platform::CUDAPlace>(
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  TensorSetConstantGPU fill_visitor(context, tensor, value);
  framework::VisitDataType(tensor->type(), fill_visitor);
}

Q
qingqing01 已提交
72
// Element-wise broadcast add: c[i] = a[i] + b[i % width] for 0 <= i < num.
// `a` and `c` are treated as consecutive rows of `width` elements each, and
// `b` holds one row (`width` elements). The loop strides by
// blockDim.x * gridDim.x (grid-stride), so any 1-D launch configuration
// covers all `num` elements. Precondition: width > 0.
template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
                                 int num) {
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += stride) {
    // Exact integer column index. Deriving the row as i * (1.0 / width) in
    // T's precision can round to the wrong row (large i, or 1/width not
    // exactly representable). That puts the column outside [0, width) and
    // reads `b` out of bounds. For integral T, 1.0 / width would also
    // truncate to 0.
    const int w = i % width;
    c[i] = a[i] + b[w];
  }
}

// GPU RowwiseAdd: output = input + vector, with `vector` added to every row.
// `input` is viewed as a matrix of in_dims[0] rows and
// numel / in_dims[0] columns. `vector` must hold exactly one row's worth of
// elements, and `output` must have the same dims as `input`. `output` is
// written through data<T>(), so it presumably must already be allocated by
// the caller (TODO confirm).
template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& vector, framework::Tensor* output) {
    auto in_dims = input.dims();
    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
    const int64_t numel = input.numel();
    // Empty input: nothing to add. Returning here also avoids dividing by a
    // zero leading dimension and launching with a zero-sized grid, which is
    // an invalid launch configuration.
    if (numel == 0) return;
    auto size = numel / in_dims[0];
    PADDLE_ENFORCE_EQ(vector.numel(), size);

    const int blocks = 512;
    const int grids = static_cast<int>((numel + blocks - 1) / blocks);
    // The row width is the flattened row length `size`, matching the check
    // above for inputs of any rank. For 2-D inputs it equals in_dims[1].
    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
        input.data<T>(), vector.data<T>(), output->data<T>(),
        static_cast<int>(size), static_cast<int>(numel));
  }
};

Q
QI JUN 已提交
101 102 103
// Explicit GPU instantiations. ColwiseSum<CUDADeviceContext, double> is
// deliberately absent; its operator() is specialized by hand further down.
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct RowwiseAdd<platform::CUDADeviceContext, float>;

template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
Q
QI JUN 已提交
106 107
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// Explicitly instantiating ColwiseSum<CUDADeviceContext, double> from the
// generic implementation failed in debug mode (and only for this
// combination), so operator() is re-implemented here.
//
// Computes column sums: vector[j] = sum_i input[i][j], where `input` is
// viewed as a matrix of in_dims[0] rows and numel / in_dims[0] columns. It
// multiplies the transposed matrix by a vector of ones using BLAS GEMV.
// `vector` is written through data<double>(), so it presumably must already
// be allocated with `size` elements (TODO confirm).
template <>
void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), size);
  framework::Tensor one;
  one.mutable_data<double>({in_dims[0]}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // GEMV(trans_a, M, N, alpha, A, B, beta, C): here A = `input` (M x N) and
  // B = `one` (length M). With trans_a = true, C = A^T * B has length N.
  // N is `size` rather than in_dims[1], so the GEMV shape agrees with the
  // check above for inputs of rank > 2. For 2-D inputs the two are equal.
  GetBlas<platform::CUDADeviceContext, double>(context).GEMV(
      true, static_cast<int>(in_dims[0]), static_cast<int>(size), 1.0,
      input.data<double>(), one.data<double>(), 0.0, vector->data<double>());
}
124

C
chengduoZH 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// Following ColwiseSum (whose double instantiation failed in debug mode),
// RowwiseSum<CUDADeviceContext, double>::operator() is re-implemented here.
//
// Computes row sums: vector[i] = sum_j input[i][j], where `input` is viewed
// as a matrix of in_dims[0] rows and numel / in_dims[0] columns. It
// multiplies the matrix by a vector of ones using BLAS GEMV.
// NOTE(review): the GEMV operands were previously swapped (see below). Please
// confirm with a GPU unit test against the CPU RowwiseSum.
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
  framework::Tensor one;
  one.mutable_data<double>({size}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // Same GEMV(trans_a, M, N, alpha, A, B, beta, C) convention as ColwiseSum:
  // A is the M x N matrix and B the vector. Without transposition,
  // C = A * B, i.e. the row sums. So A must be `input` (M = in_dims[0],
  // N = size) and B must be `one` (length N).
  // Passing `one` as the matrix would read M * N elements from a buffer that
  // holds only `size` elements.
  GetBlas<platform::CUDADeviceContext, double>(context).GEMV(
      false, static_cast<int>(in_dims[0]), static_cast<int>(size), 1.0,
      input.data<double>(), one.data<double>(), 0.0, vector->data<double>());
}

// Explicit GPU instantiations of RowwiseMean for the floating-point types.
template struct RowwiseMean<platform::CUDADeviceContext, double>;
template struct RowwiseMean<platform::CUDADeviceContext, float>;

Q
qijun 已提交
149 150 151
}  // namespace math
}  // namespace operators
}  // namespace paddle