// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_pool_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

// Grid-stride loop over [0, n): each thread starts at its flat global index
// in a 1-D launch and advances by the total number of launched threads, so
// the body is correct for any grid size.
#define CUDA_KERNEL_LOOP(i, n)                                   \
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; (i) < (n); \
       i += gridDim.x * blockDim.x)

// Average-pools every sequence column-wise.
//
// src_in is read as rows of slice_size elements, and sequence b covers rows
// [seq_offset[b], seq_offset[b + 1]). For each b < batch_size and
// j < slice_size, dst[b * slice_size + j] receives the mean of column j over
// that sequence's rows. Expects a 1-D launch. Any grid size is correct because
// the body runs inside CUDA_KERNEL_LOOP.
// Empty sequences write 0 instead of evaluating 0 / 0. This presumably
// matches the op's default pad_value of 0 — confirm against the CPU kernel.
template <typename Dtype>
__global__ void seq_pool_average_kernel(Dtype* dst,
                                        const Dtype* src_in,
                                        const int batch_size,
                                        const uint64_t* seq_offset,
                                        const int slice_size) {
  int total = slice_size * batch_size;
  CUDA_KERNEL_LOOP(tid, total) {
    int out_batch_id = tid / slice_size;
    int out_id = tid % slice_size;
    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
                                        seq_offset[out_batch_id]);
    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
    // Derive a fresh pointer per element. Advancing src_in itself would
    // compound across iterations of the grid-stride loop.
    const Dtype* in = src_in + in_offset + out_id;
    Dtype sum = static_cast<Dtype>(0);
    for (int i = 0; i < in_slice_num; ++i) {
      sum += in[i * slice_size];
    }
    dst[tid] = in_slice_num > 0 ? sum / in_slice_num : static_cast<Dtype>(0);
  }
}

// Sum-pools every sequence column-wise.
//
// src_in is read as rows of slice_size elements, and sequence b covers rows
// [seq_offset[b], seq_offset[b + 1]). For each b < batch_size and
// j < slice_size, dst[b * slice_size + j] receives the sum of column j over
// that sequence's rows. An empty sequence naturally sums to 0. Expects a 1-D
// launch. Any grid size is correct because the body runs inside
// CUDA_KERNEL_LOOP.
template <typename Dtype>
__global__ void seq_pool_sum_kernel(Dtype* dst,
                                    const Dtype* src_in,
                                    const int batch_size,
                                    const uint64_t* seq_offset,
                                    const int slice_size) {
  int total = slice_size * batch_size;
  CUDA_KERNEL_LOOP(tid, total) {
    int out_batch_id = tid / slice_size;
    int out_id = tid % slice_size;
    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
                                        seq_offset[out_batch_id]);
    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
    // Derive a fresh pointer per element. Advancing src_in itself would
    // compound across iterations of the grid-stride loop.
    const Dtype* in = src_in + in_offset + out_id;
    Dtype sum = static_cast<Dtype>(0);
    for (int i = 0; i < in_slice_num; ++i) {
      sum += in[i * slice_size];
    }
    dst[tid] = sum;
  }
}

// SQRT-pools every sequence column-wise: the column sum divided by the square
// root of the sequence length.
//
// src_in is read as rows of slice_size elements, and sequence b covers rows
// [seq_offset[b], seq_offset[b + 1]). For each b < batch_size and
// j < slice_size, dst[b * slice_size + j] = sum_j * rsqrtf(len_b). Expects a
// 1-D launch. Any grid size is correct because the body runs inside
// CUDA_KERNEL_LOOP.
// Empty sequences write 0, because 0 * rsqrtf(0) would be NaN. This presumably
// matches the op's default pad_value of 0 — confirm against the CPU kernel.
template <typename Dtype>
__global__ void seq_pool_sqrt_kernel(Dtype* dst,
                                     const Dtype* src_in,
                                     const int batch_size,
                                     const uint64_t* seq_offset,
                                     const int slice_size) {
  int total = slice_size * batch_size;
  CUDA_KERNEL_LOOP(tid, total) {
    int out_batch_id = tid / slice_size;
    int out_id = tid % slice_size;
    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
                                        seq_offset[out_batch_id]);
    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
    // Derive a fresh pointer per element. Advancing src_in itself would
    // compound across iterations of the grid-stride loop.
    const Dtype* in = src_in + in_offset + out_id;
    Dtype sum = static_cast<Dtype>(0);
    for (int i = 0; i < in_slice_num; ++i) {
      sum += in[i * slice_size];
    }
    dst[tid] = in_slice_num > 0
                   ? sum * rsqrtf(static_cast<float>(in_slice_num))
                   : static_cast<Dtype>(0);
  }
}

// Max-pools every sequence column-wise.
//
// src_in is read as rows of slice_size elements, and sequence b covers rows
// [seq_offset[b], seq_offset[b + 1]). For each b < batch_size and
// j < slice_size, dst[b * slice_size + j] receives the maximum of column j
// over that sequence's rows. Expects a 1-D launch. Any grid size is correct
// because the body runs inside CUDA_KERNEL_LOOP.
// Empty sequences write 0 rather than reading a neighbouring (or
// out-of-bounds) row. This presumably matches the op's default pad_value of
// 0 — confirm against the CPU kernel.
template <typename Dtype>
__global__ void seq_pool_max_kernel(Dtype* dst,
                                    const Dtype* src_in,
                                    const int batch_size,
                                    const uint64_t* seq_offset,
                                    const int slice_size) {
  int total = slice_size * batch_size;
  CUDA_KERNEL_LOOP(tid, total) {
    int out_batch_id = tid / slice_size;
    int out_id = tid % slice_size;
    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
                                        seq_offset[out_batch_id]);
    if (in_slice_num <= 0) {
      dst[tid] = static_cast<Dtype>(0);
      continue;
    }
    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
    // Derive a fresh pointer per element. Advancing src_in itself would
    // compound across iterations of the grid-stride loop.
    const Dtype* in = src_in + in_offset + out_id;
    Dtype best = in[0];
    for (int i = 1; i < in_slice_num; ++i) {
      Dtype val = in[i * slice_size];
      if (val > best) {
        best = val;
      }
    }
    dst[tid] = best;
  }
}

// Copies the last row of every sequence.
//
// src_in is read as rows of slice_size elements, and sequence b covers rows
// [seq_offset[b], seq_offset[b + 1]). For each b < batch_size and
// j < slice_size, dst[b * slice_size + j] = src_in row (seq_offset[b + 1] - 1),
// column j. Expects a 1-D launch. Any grid size is correct because the body
// runs inside CUDA_KERNEL_LOOP.
// Empty sequences write 0. Otherwise the index would land in the previous
// sequence, or before the buffer start when batch 0 is empty. This presumably
// matches the op's default pad_value of 0 — confirm against the CPU kernel.
template <typename Dtype>
__global__ void seq_pool_last_kernel(Dtype* dst,
                                     const Dtype* src_in,
                                     const int batch_size,
                                     const uint64_t* seq_offset,
                                     const int slice_size) {
  int total = slice_size * batch_size;
  CUDA_KERNEL_LOOP(tid, total) {
    int out_batch_id = tid / slice_size;
    int out_id = tid % slice_size;
    if (seq_offset[out_batch_id + 1] <= seq_offset[out_batch_id]) {
      dst[tid] = static_cast<Dtype>(0);
      continue;
    }
    int in_offset =
        (static_cast<int>(seq_offset[out_batch_id + 1]) - 1) * slice_size;
    dst[tid] = src_in[in_offset + out_id];
  }
}

// Copies the first row of every sequence.
//
// src_in is read as rows of slice_size elements, and sequence b covers rows
// [seq_offset[b], seq_offset[b + 1]). For each b < batch_size and
// j < slice_size, dst[b * slice_size + j] = src_in row seq_offset[b],
// column j. Expects a 1-D launch. Any grid size is correct because the body
// runs inside CUDA_KERNEL_LOOP.
// Empty sequences write 0. Otherwise the row read would belong to the next
// sequence, or lie past the end of the buffer for the last batch. This
// presumably matches the op's default pad_value of 0 — confirm against the
// CPU kernel.
template <typename Dtype>
__global__ void seq_pool_first_kernel(Dtype* dst,
                                      const Dtype* src_in,
                                      const int batch_size,
                                      const uint64_t* seq_offset,
                                      const int slice_size) {
  int total = slice_size * batch_size;
  CUDA_KERNEL_LOOP(tid, total) {
    int out_batch_id = tid / slice_size;
    int out_id = tid % slice_size;
    if (seq_offset[out_batch_id + 1] <= seq_offset[out_batch_id]) {
      dst[tid] = static_cast<Dtype>(0);
      continue;
    }
    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
    dst[tid] = src_in[in_offset + out_id];
  }
}

// Pools each sequence of X (as delimited by X's first LoD level) into a
// single row of Out, using the kernel selected by param.pool_type
// ("MAX", "AVERAGE", "SUM", "SQRT", "FIRST", "LAST").
//
// The sequence offsets are copied to the device buffer seq_offset_D on the
// context's exec stream, and the kernel is launched on that same stream. Out
// receives a one-level LoD {0, 1, ..., batch_size}, i.e. every output
// sequence has length 1. An unknown pool type is logged and no kernel runs.
void SequencePoolCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();

  const auto& lod = param.X->lod();
  if (lod.empty() || lod[0].empty()) {
    LOG(ERROR) << "sequence_pool expects X to carry at least one LoD level.";
    return;
  }
  // Host copy of the offsets. It only needs to live until MemcpyAsync below
  // has staged it.
  std::vector<uint64_t> seq_offset = lod[0];
  const int batch_size = static_cast<int>(seq_offset.size()) - 1;
  // Guard the division: with no sequences there is nothing to pool.
  const int slice_size =
      batch_size > 0
          ? static_cast<int>(param.Out->dims().production() / batch_size)
          : 0;
  const int total = batch_size * slice_size;

  float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
  const float* in_data = param.X->data<float>();

  seq_offset_D.Resize({static_cast<int64_t>(seq_offset.size())});
  TargetWrapperCuda::MemcpyAsync(
      seq_offset_D.mutable_data<uint64_t>(TARGET(kCUDA)),
      seq_offset.data(),
      sizeof(uint64_t) * seq_offset.size(),
      IoDirection::HtoD,
      stream);

  // Every pooling kernel shares one signature. Select one by pointer and
  // launch it once instead of repeating the launch configuration per type.
  using SeqPoolKernel =
      void (*)(float*, const float*, const int, const uint64_t*, const int);
  SeqPoolKernel kernel = nullptr;
  if (param.pool_type == "MAX") {
    kernel = seq_pool_max_kernel<float>;
  } else if (param.pool_type == "AVERAGE") {
    kernel = seq_pool_average_kernel<float>;
  } else if (param.pool_type == "SUM") {
    kernel = seq_pool_sum_kernel<float>;
  } else if (param.pool_type == "SQRT") {
    kernel = seq_pool_sqrt_kernel<float>;
  } else if (param.pool_type == "FIRST") {
    kernel = seq_pool_first_kernel<float>;
  } else if (param.pool_type == "LAST") {
    kernel = seq_pool_last_kernel<float>;
  } else {
    LOG(ERROR) << "pool type " << param.pool_type << " is not supported.";
  }

  // A grid of zero blocks is an invalid launch configuration, so skip empty
  // work entirely.
  if (kernel != nullptr && total > 0) {
    kernel<<<CUDA_GET_BLOCKS(total), CUDA_NUM_THREADS, 0, stream>>>(
        out_data,
        in_data,
        batch_size,
        seq_offset_D.data<uint64_t>(),
        slice_size);
  }
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);

  // Each output row forms its own length-1 sequence: offsets 0, 1, ..., N.
  std::vector<uint64_t> offset_new(static_cast<size_t>(batch_size) + 1);
  for (int i = 0; i <= batch_size; ++i) {
    offset_new[i] = static_cast<uint64_t>(i);
  }
  std::vector<std::vector<uint64_t>> voffset_new;
  voffset_new.push_back(offset_new);
  param.Out->set_lod(voffset_new);
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Registers SequencePoolCompute for the sequence_pool op on kCUDA / kFloat /
// kNCHW under the "def" tag. Inputs and outputs are bound to CUDA tensors.
// NOTE(review): SequencePoolCompute::Run in this file never writes MaxIndex
// even though it is bound here — confirm whether any consumer relies on it.
REGISTER_LITE_KERNEL(sequence_pool,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::SequencePoolCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();