// softmax_compute.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/softmax_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;

//! Threads per block for every softmax kernel launched by SoftmaxCompute; it
//! also sizes the per-thread shared-memory budget in PrepareForRun().
const int CUDA_NUM_THREADS = 512;

//! Dynamic shared memory backing sharemem_softmax_kernel; its size is the
//! third launch-configuration argument.
extern __shared__ char tile[];

//! Single-pass softmax that caches each thread's slice in shared memory.
//!
//! Input/output are viewed as [outer_num, axis_size, inner_num]. Global thread
//! `pos` (only when pos < total_size) takes outer index pos / inner_num and
//! inner index pos % inner_num, and normalizes the axis_size values there.
//! Needs axis_size * blockDim.x * sizeof(dtype) bytes of dynamic shared memory.
//! Thread t of a block only touches shared slots t, t + blockDim.x,
//! t + 2 * blockDim.x, ..., so no slot is shared between threads and no
//! __syncthreads() is required. Uses expf, so dtype is presumably float.
//! `outer_num` is part of the signature but is not read.
template <typename dtype>
__global__ void sharemem_softmax_kernel(int total_size,
                                        const dtype* in_data,
                                        dtype* out_data,
                                        int inner_num,
                                        int outer_num,
                                        int axis_size) {
  const int pos = blockIdx.x * blockDim.x + threadIdx.x;
  if (pos >= total_size) {
    return;
  }
  const int stride = blockDim.x;
  dtype* cache = reinterpret_cast<dtype*>(tile) + threadIdx.x;
  // Offset of element 0 of this thread's slice in the global tensors.
  const int base = (pos / inner_num) * axis_size * inner_num + pos % inner_num;

  // 1) stage the slice in shared memory
#pragma unroll
  for (int k = 0; k < axis_size; ++k) {
    cache[k * stride] = in_data[base + k * inner_num];
  }

  // 2) largest value along the axis (subtracted for numerical stability)
  dtype peak = cache[0];
#pragma unroll
  for (int k = 1; k < axis_size; ++k) {
    const dtype candidate = cache[k * stride];
    peak = (peak < candidate) ? candidate : peak;
  }

  // 3) exponentiate in place and accumulate the normalizer
  dtype denom = 0;
#pragma unroll
  for (int k = 0; k < axis_size; ++k) {
    const dtype e = expf(cache[k * stride] - peak);
    cache[k * stride] = e;
    denom += e;
  }

  // 4) normalize and write back to global memory
#pragma unroll
  for (int k = 0; k < axis_size; ++k) {
    out_data[base + k * inner_num] = cache[k * stride] / denom;
  }
}

//! Global-memory softmax path, pass 1: per-position maximum along the axis.
//!
//! Data is viewed as [outer_num, axis_size, inner_num]. For each thread
//! `pos` < total_size, scans the axis_size values at outer index
//! pos / inner_num and inner index pos % inner_num. It writes their maximum
//! (seeded with `min_data`) to out_data[pos]. `outer_num` is not read.
template <typename dtype>
__global__ void softmax_max_kernel(int total_size,
                                   const dtype* in_data,
                                   dtype* out_data,
                                   dtype min_data,
                                   int inner_num,
                                   int outer_num,
                                   int axis_size) {
  const int pos = blockIdx.x * blockDim.x + threadIdx.x;
  if (pos >= total_size) {
    return;
  }
  const int base = (pos / inner_num) * axis_size * inner_num + pos % inner_num;
  dtype best = min_data;
  for (int k = 0; k < axis_size; ++k) {
    const dtype value = in_data[base + k * inner_num];
    if (value > best) {
      best = value;
    }
  }
  out_data[pos] = best;
}

//! Global-memory softmax path, pass 2: shifted exponentials and their sum.
//!
//! Data is viewed as [outer_num, axis_size, inner_num]. For each thread
//! `pos` < total_size, every element x along the axis at outer index
//! pos / inner_num and inner index pos % inner_num is replaced in out_data
//! (same layout as in_data) by expf(x - max_data[pos]). The running total of
//! those values is stored in sum_data[pos]. `outer_num` is not read.
template <typename dtype>
__global__ void softmax_sub_exp_sum_kernel(int total_size,
                                           const dtype* in_data,
                                           dtype* out_data,
                                           const dtype* max_data,
                                           dtype* sum_data,
                                           int inner_num,
                                           int outer_num,
                                           int axis_size) {
  const int pos = blockIdx.x * blockDim.x + threadIdx.x;
  if (pos >= total_size) {
    return;
  }
  const int base = (pos / inner_num) * axis_size * inner_num + pos % inner_num;
  const dtype shift = max_data[pos];
  dtype acc = 0;
  for (int k = 0; k < axis_size; ++k) {
    const int at = base + k * inner_num;
    const dtype e = expf(in_data[at] - shift);
    out_data[at] = e;
    acc += e;
  }
  sum_data[pos] = acc;
}

//! Global-memory softmax path, pass 3: in-place normalization.
//!
//! Data is viewed as [outer_num, axis_size, inner_num]. For each thread
//! `pos` < total_size, every element along the axis at outer index
//! pos / inner_num and inner index pos % inner_num in io_data is multiplied
//! by the reciprocal 1 / sum_data[pos]. `outer_num` is not read.
template <typename dtype>
__global__ void softmax_divid_output_kernel(int total_size,
                                            dtype* io_data,
                                            const dtype* sum_data,
                                            int inner_num,
                                            int outer_num,
                                            int axis_size) {
  const int pos = blockIdx.x * blockDim.x + threadIdx.x;
  if (pos >= total_size) {
    return;
  }
  // One reciprocal per position, then multiply instead of dividing per element.
  const dtype scale = 1.f / sum_data[pos];
  const int base = (pos / inner_num) * axis_size * inner_num + pos % inner_num;
  for (int k = 0; k < axis_size; ++k) {
    io_data[base + k * inner_num] *= scale;
  }
}

//! Queries the current device once and derives the largest softmax axis the
//! shared-memory kernel can handle. Each of the CUDA_NUM_THREADS threads in a
//! block stages axis_size floats, so
//! axis_size * CUDA_NUM_THREADS * sizeof(float) must fit in sharedMemPerBlock.
//!
//! If either device query fails, the error is logged and both limits are set
//! to 0. Run()'s `axis_size_ <= max_dimsize_` test then fails for every
//! positive axis size, so the shared-memory path is never chosen with an
//! uninitialized budget.
void SoftmaxCompute::PrepareForRun() {
  int device_id = 0;
  cudaError_t err = cudaGetDevice(&device_id);
  cudaDeviceProp deviceProp;
  if (err == cudaSuccess) {
    err = cudaGetDeviceProperties(&deviceProp, device_id);
  }
  if (err != cudaSuccess) {
    LOG(ERROR) << cudaGetErrorString(err);
    // Reset the runtime's last-error slot so Run()'s cudaGetLastError() check
    // does not report this query failure a second time.
    (void)cudaGetLastError();
    sharedmem_size_ = 0;
    max_dimsize_ = 0;
    return;
  }

  sharedmem_size_ = deviceProp.sharedMemPerBlock;
  max_dimsize_ = sharedmem_size_ / sizeof(float) / CUDA_NUM_THREADS;
}

//! Softmax of param.x along param.axis (negative values count from the end),
//! written to param.output.
//!
//! The tensor is viewed as [outer_num, axis_size_, inner_num], and one thread
//! is launched per (outer, inner) position. When the axis fits in the
//! shared-memory budget from PrepareForRun() (axis_size_ <= max_dimsize_), a
//! single fused kernel is used. Otherwise three global-memory passes run
//! (max, exp + sum, divide), with tmax_data_/tsum_data_ as scratch. All
//! launches go to the context's exec stream and no synchronization happens
//! here; launch errors are logged, not raised.
void SoftmaxCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();

  auto x_dims = param.x->dims();
  auto x_rank = x_dims.size();
  int axis = param.axis;
  if (axis < 0) {
    axis += x_rank;
  }
  int outer_num = x_dims.Slice(0, axis).production();
  int inner_num = x_dims.Slice(axis + 1, x_rank).production();
  int total_threads = inner_num * outer_num;
  axis_size_ = x_dims[axis];

  const int threads = CUDA_NUM_THREADS;
  const int blocks = (total_threads + threads - 1) / threads;
  auto input_data = param.x->data<float>();
  auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));

  // Empty tensor: nothing to compute, and launching would be wrong. A grid of
  // 0 blocks is an invalid launch configuration. With axis_size_ == 0, the
  // shared-memory kernel would also read data[0] from a 0-byte dynamic
  // shared-memory allocation.
  if (total_threads == 0 || axis_size_ == 0) {
    return;
  }

  if (axis_size_ <= max_dimsize_) {
    // The whole axis fits in shared memory: one fused kernel.
    int use_sharemem_size = axis_size_ * threads * sizeof(float);
    sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
        total_threads,
        input_data,
        output_data,
        inner_num,
        outer_num,
        axis_size_);
  } else {
    //! re_alloc device memory
    tmax_data_.Resize({1, 1, 1, outer_num * inner_num});
    tsum_data_.Resize({1, 1, 1, outer_num * inner_num});
    auto max_data = tmax_data_.mutable_data<float>(TARGET(kCUDA));
    auto sum_data = tsum_data_.mutable_data<float>(TARGET(kCUDA));
    //! firstly, get maximum data
    float min_data = std::numeric_limits<float>::lowest();
    softmax_max_kernel<float><<<blocks, threads, 0, stream>>>(total_threads,
                                                              input_data,
                                                              max_data,
                                                              min_data,
                                                              inner_num,
                                                              outer_num,
                                                              axis_size_);
    //! then, compute exp and sum data
    softmax_sub_exp_sum_kernel<float><<<blocks, threads, 0, stream>>>(
        total_threads,
        input_data,
        output_data,
        max_data,
        sum_data,
        inner_num,
        outer_num,
        axis_size_);
    //! last, compute divided output
    softmax_divid_output_kernel<float><<<blocks, threads, 0, stream>>>(
        total_threads, output_data, sum_data, inner_num, outer_num, axis_size_);
  }
  // Catches launch-configuration errors only; in-kernel faults surface at a
  // later synchronizing call.
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Registers SoftmaxCompute as the CUDA / float / NCHW kernel (alias "def") for
// the `softmax` op. Input "X" and output "Out" are bound as float NCHW CUDA
// tensors.
// NOTE(review): SoftmaxCompute::Run() reads the axis from param.axis, not from
// a tensor, so the "axis" tensor binding below looks unnecessary. Confirm
// against the softmax op definition before removing it.
REGISTER_LITE_KERNEL(softmax,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::SoftmaxCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("axis",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();
// Registers the same SoftmaxCompute kernel (CUDA / float / NCHW, alias "def")
// for the `search_seq_softmax` op. Besides "Out", an "Out_log" output is bound
// with only its target (CUDA) specified.
// NOTE(review): SoftmaxCompute::Run() writes only param.output; whether
// "Out_log" is filled anywhere is not visible here. Confirm.
REGISTER_LITE_KERNEL(search_seq_softmax,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::SoftmaxCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .BindOutput("Out_log", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();