Commit af95a0bb authored by Wilber, committed by GitHub

fix bug for crmm model test=develop (#2786)

- modify the aligned_matmul kernel to allocate memory dynamically
- fix the top_k_avg_pooling kernel to support data whose size exceeds the CUDA shared memory per block (both changes are sketched just below the commit header)
Parent a5d89422
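The two fixes can be summarized with a couple of standalone sketches. The helper names, the struct, and the `main` below are illustrative assumptions and do not appear in the patch; only the packing loop and the launch formulas mirror the diff further down.

First, the aligned_matmul change: the per-sequence pointer array consumed by the batched GEMM is built on every `Run()`, so its size can follow the current batch's `seq_num` rather than the one seen at prepare time. A minimal sketch of how the X, Y, and Out pointers are packed:

```cpp
#include <vector>

// Layout of the pointer array handed to the batched GEMM: the first seq_num
// entries point at the X matrices, the next seq_num at the Y matrices, the
// last seq_num at the Out matrices. BuildPointerArray is an illustrative
// stand-alone helper; in the patch the array is built directly in Run() and
// passed to batched_gemm_impl_->run(alpha, 0.f, A, M, N, K, seq_num).
std::vector<const float*> BuildPointerArray(const float* x_data,
                                            const float* y_data,
                                            float* out_data,
                                            int seq_num,
                                            int x_stride,
                                            int y_stride,
                                            int out_stride) {
  std::vector<const float*> A(3 * seq_num);
  for (int seq = 0; seq < seq_num; ++seq) {
    A[seq] = x_data + seq * x_stride;                    // X matrices
    A[seq + seq_num] = y_data + seq * y_stride;          // Y matrices
    A[seq + seq_num * 2] = out_data + seq * out_stride;  // Out matrices
  }
  return A;
}
```

Second, the top_k_avg_pooling change: `PrepareForRun()` records `sharedMemPerBlock`, and `Run()` falls back to the tiled `topk_avg_pooling_kernel_for_big_data` whenever one feature map no longer fits in shared memory, splitting the map's rows across `gridDim.z`. A condensed sketch of that dispatch decision, using the same formulas as the diff:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical helper type describing how the pooling launch is tiled; it is
// not part of the patch, but the fields correspond to the values the patch
// computes: actual_row rows of a feature map are cached per block, and the
// grid gets a z-dimension of num_z chunks when the map overflows shared mem.
struct TopkPoolLaunchPlan {
  bool use_big_data_kernel;
  int actual_row;        // rows of one feature map cached per block
  int num_z;             // gridDim.z = ceil(height / actual_row)
  size_t shared_bytes;   // dynamic shared memory per block
};

template <typename T>
TopkPoolLaunchPlan MakeLaunchPlan(int height, int width) {
  int device_id = 0;
  cudaGetDevice(&device_id);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  const size_t shared_mem_size = prop.sharedMemPerBlock;  // _shared_mem_size

  TopkPoolLaunchPlan plan{};
  const size_t feat_map_bytes =
      static_cast<size_t>(height) * width * sizeof(T);
  if (feat_map_bytes <= shared_mem_size) {
    // Whole feature map fits: original kernel, one chunk, no z-dimension.
    plan.use_big_data_kernel = false;
    plan.actual_row = height;
    plan.num_z = 1;
    plan.shared_bytes = feat_map_bytes;
  } else {
    // Overflow: cache as many full rows as fit and split the rest over z.
    plan.use_big_data_kernel = true;
    plan.actual_row = static_cast<int>(shared_mem_size / width / sizeof(T));
    plan.num_z = (height + plan.actual_row - 1) / plan.actual_row;
    plan.shared_bytes =
        static_cast<size_t>(plan.actual_row) * width * sizeof(T);
  }
  return plan;
}

int main() {
  // Example: 4096 x 300 floats (~4.9 MB) cannot fit in a typical 48 KB of
  // shared memory, so the big-data kernel is chosen and rows are tiled.
  TopkPoolLaunchPlan plan = MakeLaunchPlan<float>(4096, 300);
  printf("big_data=%d actual_row=%d num_z=%d shared_bytes=%zu\n",
         plan.use_big_data_kernel, plan.actual_row, plan.num_z,
         plan.shared_bytes);
  return 0;
}
```

In the patch itself the device property is queried once in `PrepareForRun()` and only the size comparison and tiling arithmetic run per `Run()`, as the hunks below show.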
@@ -31,21 +31,12 @@ class SearchAlignedMatMulCompute
using param_t = operators::MatMulParam;
void PrepareForRun() override {
auto& param = this->Param<param_t>();
CHECK(ctx_) << "running context should be set first";
auto& cuda_ctx = ctx_->template As<CUDAContext>();
bool x_transpose = param.transpose_X;
bool y_transpose = param.transpose_Y;
int seq_num = param.X->lod()[0].size() - 1;
batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm<float, float>);
CHECK(
batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx));
A_ = static_cast<float**>(malloc(3 * seq_num * sizeof(float*)));
CHECK(A_);
}
void Run() override {
auto& param = this->Param<param_t>();
auto& cuda_ctx = ctx_->template As<CUDAContext>();
auto x = param.X;
auto y = param.Y;
auto out = param.Out;
@@ -76,25 +67,25 @@ class SearchAlignedMatMulCompute
auto x_stride = x_batch_size * x_inner_size;
auto y_stride = y_batch_size * y_inner_size;
auto out_stride = M * N;
for (int seq = 0; seq < seq_num; seq++) {
float* A_[seq_num * 3];
for (int seq = 0; seq < seq_num; ++seq) {
A_[seq] = const_cast<float*>(x_data) + seq * x_stride;
A_[seq + seq_num] = const_cast<float*>(y_data) + seq * y_stride;
A_[seq + seq_num * 2] = out_data + seq * out_stride;
}
CHECK(
batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx));
batched_gemm_impl_->run(
alpha, 0.0f, const_cast<const float**>(A_), M, N, K, seq_num);
}
~SearchAlignedMatMulCompute() {
if (A_ != nullptr) {
free(A_);
}
}
~SearchAlignedMatMulCompute() { batched_gemm_impl_.reset(); }
private:
std::unique_ptr<lite::cuda::math::BatchedGemm<float, float>>
batched_gemm_impl_;
float** A_{nullptr};
};
} // namespace cuda
......
@@ -93,6 +93,111 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
}
}
template <typename Dtype>
__global__ void topk_avg_pooling_kernel_for_big_data(
Dtype *output_data,
const Dtype *input_data,
const int *gpu_input_offset_l,
const int *gpu_input_offset_r,
const int row_max,
const int col_max,
const int topk_size,
const int *topks,
const int feat_map_num,
const int actual_row_in_shared_mem) {
int row = gpu_input_offset_l[blockIdx.x + 1] -
gpu_input_offset_l[blockIdx.x]; // 75
int col = gpu_input_offset_r[blockIdx.x + 1] -
gpu_input_offset_r[blockIdx.x]; // 300
int max_k = topks[topk_size - 1];
max_k = max_k < col ? max_k : col;
extern __shared__ Dtype smem[]; // H1*W or H2*W ...
int filled_z = row / actual_row_in_shared_mem;
int remain_row = row - filled_z * actual_row_in_shared_mem;
if (blockIdx.z > filled_z || (blockIdx.z == filled_z && remain_row == 0)) {
return;
}
const Dtype *fm_row_in_data = input_data +
blockIdx.x * row_max * feat_map_num * col_max +
blockIdx.y * row_max * col_max +
blockIdx.z * actual_row_in_shared_mem * col_max;
if (blockIdx.z == filled_z) {
for (int i = threadIdx.x; i < remain_row * col_max; i += blockDim.x) {
smem[i] = fm_row_in_data[i];
}
} else {
for (int i = threadIdx.x; i < actual_row_in_shared_mem * col_max;
i += blockDim.x) {
smem[i] = fm_row_in_data[i];
}
}
__syncthreads();
int cur_row;
if (blockIdx.z == filled_z) {
cur_row = remain_row;
} else {
cur_row = actual_row_in_shared_mem;
}
for (int idx = threadIdx.x; idx < cur_row; idx += blockDim.x) {
Dtype *fm_row_out_data = output_data +
(gpu_input_offset_l[blockIdx.x] +
blockIdx.z * actual_row_in_shared_mem + idx) *
feat_map_num * topk_size +
blockIdx.y * topk_size;
Dtype *smem_start_col = smem + idx * col_max;
int counter = max_k; // topk_size;
Dtype last_max_val = -20000.0;
while (counter) {
Dtype max_val = -10000.0;
int max_pos = 0; // -1;
int m = 0;
for (; m < col; m++) {
Dtype cur_data = smem_start_col[m];
if (cur_data > max_val) {
max_val = cur_data;
max_pos = m;
last_max_val = max_val;
}
}
if (max_val < -9999.0) { // == -10000.0
max_val = last_max_val;
}
smem_start_col[max_pos] = -10000000.0;
int i = max_k - counter;
for (int c = 0; c < topk_size; c++) {
if (i <= topks[c] - 1) {
fm_row_out_data[c] += max_val;
}
}
counter--;
}
__syncthreads();
// compute avg
for (int i = 0; i < topk_size; i++) {
fm_row_out_data[i] = fm_row_out_data[i] / topks[i];
}
}
}
template <typename T>
void SequenceTopkAvgPoolingCompute<T>::PrepareForRun() {
int device_id;
cudaGetDevice(&device_id);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, device_id);
_shared_mem_size = deviceProp.sharedMemPerBlock;
}
template <typename T>
void SequenceTopkAvgPoolingCompute<T>::Run() {
auto &param = this->Param<param_t>();
@@ -156,20 +261,40 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
int feat_map_size = height * width;
dim3 blocks(num, channel);
dim3 threads(32, 1);
topk_avg_pooling_kernel_by_row_improve<
T><<<blocks, threads, feat_map_size * sizeof(T), cuda_stream>>>(
out_data,
in_data,
height_offset,
width_offset,
height,
width,
param.topks.size(),
_top_ks.data<int>(),
param.channel_num);
if (feat_map_size * sizeof(T) <= _shared_mem_size) {
dim3 blocks(num, channel);
dim3 threads(32, 1);
topk_avg_pooling_kernel_by_row_improve<
T><<<blocks, threads, feat_map_size * sizeof(T), cuda_stream>>>(
out_data,
in_data,
height_offset,
width_offset,
height,
width,
param.topks.size(),
_top_ks.data<int>(),
param.channel_num);
} else {
int actual_row = _shared_mem_size / width / sizeof(T);
int num_z = (height + actual_row - 1) / actual_row;
dim3 blocks(num, channel, num_z);
dim3 threads(32, 1);
topk_avg_pooling_kernel_for_big_data<
T><<<blocks, threads, actual_row * width * sizeof(T), cuda_stream>>>(
out_data,
in_data,
height_offset,
width_offset,
height,
width,
param.topks.size(),
_top_ks.data<int>(),
param.channel_num,
actual_row);
}
}
} // namespace cuda
......
@@ -29,12 +29,15 @@ class SequenceTopkAvgPoolingCompute
void Run() override;
void PrepareForRun() override;
virtual ~SequenceTopkAvgPoolingCompute() = default;
protected:
lite::Tensor _height_offset;
lite::Tensor _width_offset;
lite::Tensor _top_ks;
int _shared_mem_size;
};
} // namespace cuda
......