sequence_topk_avg_pooling_compute.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sequence_topk_avg_pooling_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

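// For each (sequence, channel) pair, this kernel copies the valid rows of the
// padded feature map into dynamic shared memory, then, for every valid row,
// selects the top-k values by repeated max extraction and writes the running
// top-k averages to the output. blockIdx.x indexes the sequence and
// blockIdx.y indexes the channel (feature map).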
template <typename Dtype>
__global__ void topk_avg_pooling_kernel_by_row_improve(
    Dtype *output_data,
    const Dtype *input,
    const int *gpu_input_offset_l,
    const int *gpu_input_offset_r,
    const int row_max,
    const int col_max,
    const int topk_size,
    const int *topks,
    const int feat_map_num) {
  int row =
      gpu_input_offset_l[blockIdx.x + 1] - gpu_input_offset_l[blockIdx.x];  // 8
  int col = gpu_input_offset_r[blockIdx.x + 1] -
            gpu_input_offset_r[blockIdx.x];  // 30

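  // The largest requested k cannot exceed the number of valid columns.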
  int max_k = topks[topk_size - 1];
  max_k = max_k < col ? max_k : col;

  extern __shared__ Dtype smem[];  // H*W

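  // Start of this (sequence, channel) feature map in the padded input, which
  // is laid out as [batch, feat_map_num, row_max, col_max].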
  const Dtype *fm_row_in_data = input +
                                blockIdx.x * row_max * feat_map_num * col_max +
                                blockIdx.y * row_max * col_max;

  for (int i = threadIdx.x; i < row * col_max; i += blockDim.x) {
    smem[i] = fm_row_in_data[i];
  }
  __syncthreads();

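  // Each thread owns one valid row: repeatedly pop the current maximum from
  // the row (up to max_k times) and add it to every top-k slot whose k is
  // large enough, then divide each slot by its k to obtain the averages.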
  for (int idx = threadIdx.x; idx < row; idx += blockDim.x) {
    Dtype *fm_row_out_data =
        output_data +
        (gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size +
        blockIdx.y * topk_size;

    Dtype *smem_start_col = smem + idx * col_max;

    int counter = max_k;  // topk_size;
    Dtype last_max_val = -20000.0;
    while (counter) {
      Dtype max_val = -10000.0;
      int max_pos = 0;  // -1;
      int m = 0;
      for (; m < col; m++) {
        Dtype cur_data = smem_start_col[m];
        if (cur_data > max_val) {
          max_val = cur_data;
          max_pos = m;
          last_max_val = max_val;
        }
      }
      if (max_val < -9999.0) {  // == -10000.0
        max_val = last_max_val;
      }
      smem_start_col[max_pos] = -10000000.0;

      int i = max_k - counter;
      for (int c = 0; c < topk_size; c++) {
        if (i <= topks[c] - 1) {
          fm_row_out_data[c] += max_val;
        }
      }
      counter--;
    }
    __syncthreads();
    // compute avg
    for (int i = 0; i < topk_size; i++) {
      fm_row_out_data[i] = fm_row_out_data[i] / topks[i];
    }
  }
}

template <typename T>
void SequenceTopkAvgPoolingCompute<T>::Run() {
  auto &param = this->Param<param_t>();
  auto &ctx = this->ctx_->template As<CUDAContext>();
  auto cuda_stream = ctx.exec_stream();

  CHECK(param.X->lod().size() > 0 && param.X->lod()[0].size() > 0)
      << "X sequence offset is not valid";
  CHECK(param.ROW->lod().size() > 0 && param.ROW->lod()[0].size() > 0)
      << "ROW sequence offset is not valid";

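  // Copy X's level-0 LoD (per-sequence width offsets) to the device.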
  int width_offset_len = param.X->lod()[0].size();
  lite::DDim width_offset_shape(std::vector<int64_t>{width_offset_len});
  _width_offset.Resize(width_offset_shape);
  std::vector<int> width_lod_0(width_offset_len, 0);
  for (size_t i = 0; i < param.X->lod()[0].size(); ++i) {
    width_lod_0[i] = static_cast<int>(param.X->lod()[0][i]);
  }
  cudaMemcpyAsync(_width_offset.mutable_data<int>(TARGET(kCUDA)),
                  &width_lod_0[0],
                  sizeof(int) * width_offset_len,
                  cudaMemcpyHostToDevice,
                  cuda_stream);

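  // Copy ROW's level-0 LoD (per-sequence height offsets) to the device.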
  int height_offset_len = param.ROW->lod()[0].size();
  lite::DDim height_offset_shape(std::vector<int64_t>{height_offset_len});
  _height_offset.Resize(height_offset_shape);
  std::vector<int> height_lod_0(height_offset_len, 0);
  for (size_t i = 0; i < param.ROW->lod()[0].size(); ++i) {
    height_lod_0[i] = static_cast<int>(param.ROW->lod()[0][i]);
  }
  cudaMemcpyAsync(_height_offset.mutable_data<int>(TARGET(kCUDA)),
                  &height_lod_0[0],
                  sizeof(int) * height_offset_len,
                  cudaMemcpyHostToDevice,
                  cuda_stream);

  const Tensor *x_tensor = param.X;
  Tensor *out_tensor = param.Out;
  const T *in_data = x_tensor->data<T>();
  T *out_data = out_tensor->mutable_data<T>(TARGET(kCUDA));
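  // The kernel accumulates into Out, so clear it first.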
  TargetWrapperCuda::MemsetAsync(
      out_data, 0, sizeof(T) * param.Out->numel(), cuda_stream);

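  // Stage the requested top-k values on the device.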
  int topk_num = param.topks.size();
  lite::DDim top_ks_shape(std::vector<int64_t>{topk_num, 1, 1, 1});
  _top_ks.Resize(top_ks_shape);
  cudaMemcpyAsync(_top_ks.mutable_data<int>(TARGET(kCUDA)),
                  &param.topks[0],
                  sizeof(int) * topk_num,
                  cudaMemcpyHostToDevice,
                  cuda_stream);

  int num = param.X->dims()[0];
  int channel = param.X->dims()[1];
  int height = param.X->dims()[2];
  int width = param.X->dims()[3];

  const int *height_offset = _height_offset.data<int>();
  const int *width_offset = _width_offset.data<int>();

  int feat_map_size = height * width;

  dim3 blocks(num, channel);
  dim3 threads(32, 1);

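  // Launch one block per (batch item, channel); the dynamic shared memory
  // holds one height x width feature map.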
  topk_avg_pooling_kernel_by_row_improve<
      T><<<blocks, threads, feat_map_size * sizeof(T), cuda_stream>>>(
      out_data,
      in_data,
      height_offset,
      width_offset,
      height,
      width,
      param.topks.size(),
      _top_ks.data<int>(),
      param.channel_num);
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

REGISTER_LITE_KERNEL(
    sequence_topk_avg_pooling,
    kCUDA,
    kFloat,
    kNCHW,
    paddle::lite::kernels::cuda::SequenceTopkAvgPoolingCompute<float>,
    def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("ROW",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("COLUMN",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .BindOutput("pos",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();