sequence_softmax_op.cu 5.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cub/cub.cuh>  // NOLINT
W
Wu Yi 已提交
17
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
18 19 20 21 22 23

namespace paddle {
namespace operators {

// Short local name for the framework's LoD (level-of-detail) tensor type.
typedef framework::LoDTensor LoDTensor;

24 25 26
// Precision-matched exponential for device code: the float overload calls
// single-precision expf, the double overload calls exp, so kernels templated
// on T never silently promote float math to double.
__device__ __forceinline__ float real_exp(float value) {
  return expf(value);
}
__device__ __forceinline__ double real_exp(double value) {
  return exp(value);
}

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
// CUB block-wide reduction over values of type T for a block of kBlockSize
// threads (the block size is a compile-time parameter of the collective).
template <typename T, int kBlockSize>
using BlockReduce = cub::BlockReduce<T, kBlockSize>;

// Shared-memory scratch type required by BlockReduce<T, kBlockSize>.
template <typename T, int kBlockSize>
using BlockReduceTempStorage =
    typename cub::BlockReduce<T, kBlockSize>::TempStorage;

// Numerically stable softmax computed independently over each sequence.
// Sequence i covers in_data[ref_lod[i], ref_lod[i + 1]); its result is
// written to the same positions of out_data as
//   out = exp(x - max(x)) / sum(exp(x - max(x))).
//
// Work distribution: each block processes one sequence at a time and moves to
// the next by striding over sequences with gridDim.x. The block's threads
// stride over that sequence's elements with blockDim.x. BlockDim must equal
// blockDim.x at launch because it sizes the CUB reduction. Only statically
// declared shared memory is used (no dynamic shared memory argument).
//
// in_data:   input values.
// ref_lod:   device-accessible array of src_hight + 1 sequence offsets.
// src_hight: number of sequences.
// out_data:  output buffer with the same layout as in_data.
template <typename T, int BlockDim>
__global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
                                        const size_t src_hight, T *out_data) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
  __shared__ T shared_max_data;
  __shared__ T shared_sum_data;

  // `i` is the same for every thread of the block, so every thread reaches
  // each __syncthreads() below.
  for (size_t i = blockIdx.x; i < src_hight; i += gridDim.x) {
    const size_t start = ref_lod[i];
    const size_t span = ref_lod[i + 1] - start;

    // Find the max element. Seed from an element of the sequence itself
    // rather than -FLT_MAX, which is not the lowest value when T is double.
    // An empty sequence writes no output, so its seed value is irrelevant.
    T max_ele = span > 0 ? in_data[start] : static_cast<T>(0);
    for (size_t tid = threadIdx.x; tid < span; tid += blockDim.x) {
      T ele = in_data[start + tid];
      max_ele = max_ele > ele ? max_ele : ele;
    }
    // The reduced value is only valid in thread 0; broadcast via shared mem.
    max_ele =
        BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, cub::Max());
    if (threadIdx.x == 0) {
      shared_max_data = max_ele;
    }
    // Publishes shared_max_data and frees temp_storage for the next Reduce.
    __syncthreads();

    // Sum of shifted exponentials.
    T sum_data = 0;
    for (size_t tid = threadIdx.x; tid < span; tid += blockDim.x) {
      T ele = in_data[start + tid];
      sum_data += real_exp(ele - shared_max_data);
    }
    sum_data =
        BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, cub::Sum());
    if (threadIdx.x == 0) {
      shared_sum_data = sum_data;
    }
    // Publishes shared_sum_data and frees temp_storage for reuse.
    __syncthreads();

    // Get the final result.
    for (size_t tid = threadIdx.x; tid < span; tid += blockDim.x) {
      T ele = in_data[start + tid];
      ele = real_exp(ele - shared_max_data) / shared_sum_data;
      out_data[start + tid] = ele;
    }

    // All threads must finish reading shared_max_data / shared_sum_data
    // before thread 0 overwrites them for the next sequence. Without this
    // barrier, correctness depends on CUB's Reduce happening to contain an
    // internal block barrier, which its documented contract does not promise.
    __syncthreads();
  }
}

// Backward pass of the per-sequence softmax. For each sequence i (elements
// [ref_lod[i], ref_lod[i + 1])), with y = softmax_data and
// dy = softmax_grad_data, computes
//   dx = (dy - sum_j(dy_j * y_j)) * y
// where the sum runs over the elements of that sequence.
//
// Work distribution: each block processes one sequence at a time and moves to
// the next by striding over sequences with gridDim.x. The block's threads
// stride over elements with blockDim.x. BlockDim must equal blockDim.x at
// launch because it sizes the CUB reduction.
//
// softmax_grad_data: upstream gradient dy.
// softmax_data:      forward softmax output y.
// ref_lod:           device-accessible array of src_hight + 1 offsets.
// src_hight:         number of sequences.
// dx_data:           output gradient with the same layout as y.
template <typename T, int BlockDim>
__global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
                                             const T *softmax_data,
                                             const size_t *ref_lod,
                                             const size_t src_hight,
                                             T *dx_data) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
  __shared__ T shared_data;

  // `i` is the same for every thread of the block, so every thread reaches
  // each __syncthreads() below.
  for (size_t i = blockIdx.x; i < src_hight; i += gridDim.x) {
    const size_t start = ref_lod[i];
    const size_t span = ref_lod[i + 1] - start;

    // Per-thread partial of sum(dy * y) over the sequence.
    T result = 0;
    for (size_t tid = threadIdx.x; tid < span; tid += blockDim.x) {
      const size_t idx = start + tid;
      T s_g_d = softmax_grad_data[idx];
      T s_d = softmax_data[idx];
      result += s_g_d * s_d;
    }
    // The reduced value is only valid in thread 0; broadcast via shared mem.
    result = BlockReduce<T, BlockDim>(temp_storage).Reduce(result, cub::Sum());
    if (threadIdx.x == 0) {
      shared_data = result;
    }
    // Publishes shared_data and frees temp_storage for the next Reduce.
    __syncthreads();

    for (size_t tid = threadIdx.x; tid < span; tid += blockDim.x) {
      const size_t idx = start + tid;
      T s_g_d = softmax_grad_data[idx];
      T s_d = softmax_data[idx];
      dx_data[idx] = (s_g_d - shared_data) * s_d;
    }

    // All threads must finish reading shared_data before thread 0 overwrites
    // it for the next sequence. Without this barrier, correctness depends on
    // CUB's Reduce happening to contain an internal block barrier.
    __syncthreads();
  }
}

// CUDA specialization of the forward functor. `ref_lod` holds sequence
// offsets into `x` (sequence i is x[ref_lod[i], ref_lod[i + 1])). Allocates
// `out` on the context's place and launches sequence_softmax_kernel on the
// context's stream to fill it with the per-sequence softmax of `x`.
template <typename T>
struct SequenceSoftmaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext &context,
                  const LoDTensor &x,
                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
                  LoDTensor *out) {
    // Allocate first so the output exists even when there is nothing to do.
    T *out_data = out->mutable_data<T>(context.GetPlace());

    // N offsets describe N - 1 sequences. With an empty ref_lod,
    // `size() - 1` would wrap to a huge count and the kernel would read far
    // past the offsets array. With a single offset there are no sequences.
    if (ref_lod.size() < 2) {
      return;
    }
    const size_t hight = ref_lod.size() - 1;

    // The block size must match the kernel's BlockDim template argument,
    // which sizes the CUB reduction storage.
    constexpr int kThreadsPerBlock = 32;
    const int max_threads = context.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
    // Each block handles whole sequences, so blocks beyond the sequence count
    // would only idle.
    const int num_blocks =
        static_cast<int>(std::min<size_t>(max_blocks, hight));

    dim3 block_size(kThreadsPerBlock);
    dim3 grid_size(num_blocks);
    sequence_softmax_kernel<
        T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
        x.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight, out_data);
  }
};

// CUDA specialization of the backward functor. `ref_lod` holds sequence
// offsets shared by `dout`, `out` and `dx`. Allocates `dx` on the context's
// place and launches sequence_softmax_grad_kernel on the context's stream
// with dy = dout and y = out (the forward softmax output).
template <typename T>
struct SequenceSoftmaxGradFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext &context,
                  const LoDTensor &dout, const LoDTensor &out,
                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
                  LoDTensor *dx) {
    // Allocate first so the gradient exists even when there is nothing to do.
    T *dx_data = dx->mutable_data<T>(context.GetPlace());

    // N offsets describe N - 1 sequences. With an empty ref_lod,
    // `size() - 1` would underflow to SIZE_MAX and the kernel would read far
    // past the offsets array. With a single offset there are no sequences.
    if (ref_lod.size() < 2) {
      return;
    }
    const size_t hight = ref_lod.size() - 1;

    // The block size must match the kernel's BlockDim template argument,
    // which sizes the CUB reduction storage.
    constexpr int kThreadsPerBlock = 32;
    const int max_threads = context.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
    // Each block handles whole sequences, so blocks beyond the sequence count
    // would only idle.
    const int num_blocks =
        static_cast<int>(std::min<size_t>(max_blocks, hight));

    dim3 block_size(kThreadsPerBlock);
    dim3 grid_size(num_blocks);

    sequence_softmax_grad_kernel<
        T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
        dout.data<T>(), out.data<T>(), ref_lod.CUDAData(context.GetPlace()),
        hight, dx_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the CUDA forward kernel of `sequence_softmax` for float and double.
REGISTER_OP_CUDA_KERNEL(
    sequence_softmax,
    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>);
// Register the matching CUDA gradient kernel (`sequence_softmax_grad`) for the
// same element types.
REGISTER_OP_CUDA_KERNEL(
    sequence_softmax_grad,
    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
                                   double>);