Unverified commit 3a881861, authored by huzhiqiang, committed by GitHub

add cuda kernel for sequence_topk_avg_pooling and search_fc (#2451)

* cuda kernel for sequence_topk_avg_pooling and search_fc test=develop
Parent bf2c6fca
@@ -30,6 +30,8 @@ add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_fc_compute_cuda CUDA basic SRCS search_fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_topk_avg_pooling_compute_cuda CUDA basic SRCS sequence_topk_avg_pooling_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA extra SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm)
add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
@@ -53,6 +55,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
nv_test(search_fc_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda sequence_topk_avg_pooling_compute_cuda)
nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
if(LITE_BUILD_EXTRA)
...
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_fc_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
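// Row-major gemv helper: cuBLAS expects column-major storage, so the float
// specialization below passes the dimensions swapped (N, M) and flips the
// transpose flag, which makes cublasSgemv read the row-major M x N matrix A
// as its own transpose.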
template <typename T>
static void anakin_NV_gemv(cublasHandle_t handle,
const bool TransA,
const int M,
const int N,
const T alpha,
const T* A,
const T* x,
const T beta,
T* y);
template <>
void anakin_NV_gemv<float>(cublasHandle_t handle,
const bool TransA,
const int M,
const int N,
const float alpha,
const float* A,
const float* x,
const float beta,
float* y) {
cublasOperation_t cuTransA = (TransA == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBLAS_CHECK(
cublasSgemv(handle, cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1));
}
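// Row-major gemm helper: C (M x N) = op(A) * op(B) is obtained by asking
// cuBLAS for C^T = op(B)^T * op(A)^T, i.e. the operands are passed in
// swapped order with leading dimensions taken from the row-major layout.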
template <typename T>
static void anakin_NV_gemm(cublasHandle_t handle,
const bool TransA,
const bool TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const T* B,
const T beta,
T* C);
template <>
void anakin_NV_gemm<float>(cublasHandle_t handle,
const bool TransA,
const bool TransB,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const float* B,
const float beta,
float* C) {
  // Note that cublas follows fortran (column-major) order.
  int lda = (!TransA /* == CblasNoTrans*/) ? K : M;
  int ldb = (!TransB /* == CblasNoTrans*/) ? N : K;
  cublasOperation_t cuTransA =
      (!TransA /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (!TransB /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_CHECK(cublasSgemm(handle,
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
A,
lda,
&beta,
C,
N));
}
template <>
void anakin_NV_gemm<char>(cublasHandle_t handle,
const bool TransA,
const bool TransB,
const int M,
const int N,
const int K,
const char alpha,
const char* A,
const char* B,
const char beta,
char* C) {
LOG(FATAL) << "int8 gemm is not implemented";
}
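// One thread per output element; the bias vector (length output_size) is
// broadcast across the M rows of the output.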
template <typename T>
static __global__ void add_bias(int n,
int output_size,
const T* bias,
T* dout) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int bias_index = index % output_size;
if (index < n) {
dout[index] = dout[index] + bias[bias_index];
}
}
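// search_fc: Out (M x N) = X (M x K) * W^T + b, where W is stored as N x K.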
template <typename T>
void SearchFcCompute<T>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const Tensor* x_tensor = param.X;
param.Out->Resize({x_tensor->dims()[0], param.out_size});
  _M = x_tensor->dims().count(0, 1);
  // Flatten all dimensions after the first into the inner (K) dimension.
  _K = x_tensor->dims().count(1, x_tensor->dims().size());
_N = param.out_size;
const T* din = x_tensor->data<T>();
Tensor* out_tensor = param.Out;
T* dout = out_tensor->mutable_data<T>(TARGET(kCUDA));
const Tensor* w_tensor = param.W;
const T* weight = w_tensor->data<T>();
const Tensor* b_tensor = param.b;
const T* bias = b_tensor->data<T>();
  if (_handle == nullptr) {
    CUBLAS_CHECK(cublasCreate(&_handle));
  }
  // Run the cuBLAS calls on the same stream as the add_bias kernel below.
  CUBLAS_CHECK(cublasSetStream(_handle, stream));
  // With a single input row and a very large K, gemv is cheaper than a
  // degenerate gemm.
if (_M == 1 && _K > 50000) {
anakin_NV_gemv<T>(_handle, false, _N, _K, (T)1, weight, din, (T)0, dout);
} else {
anakin_NV_gemm<T>(_handle,
false,
!_flag_trans_weights,
_M,
_N,
_K,
(T)1,
din,
weight,
(T)0,
dout);
}
int total_size = _M * _N;
add_bias<T><<<CUDA_GET_BLOCKS(total_size), CUDA_NUM_THREADS, 0, stream>>>(
total_size, _N, bias, dout);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_fc,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SearchFcCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
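// Kernel launch helpers: a fixed number of threads per block, and a ceiling
// division so the grid covers all N elements.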
const int CUDA_NUM_THREADS = 512;
inline int CUDA_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
inline int CUDA_GET_BLOCKS(const int N, const int base) {
return (N + base - 1) / base;
}
template <typename T>
class SearchFcCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SearchFcParam;
void Run() override;
  virtual ~SearchFcCompute() {
    if (_handle != nullptr) {
      cublasDestroy(_handle);
    }
  }
private:
bool _flag_trans_weights{false};
int _M;
int _K;
int _N;
  cublasHandle_t _handle{nullptr};
bool _is_continue_buf{true};
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_fc_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
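// Naive CPU reference: Out[i][j] = b[j] + sum_k X[i][k] * W[j][k],
// matching the W-transposed layout used by the CUDA kernel.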
void fc_cpu_base(const lite::Tensor* X,
const lite::Tensor* W,
const lite::Tensor* b,
int out_size,
lite::Tensor* Out) {
const float* data_in = X->data<float>();
const float* bias = b->data<float>();
const float* weights = W->data<float>();
float* data_out = Out->mutable_data<float>();
int out_rows = X->dims()[0];
int in_cols = X->numel() / out_rows;
int out_cols = W->numel() / in_cols;
int index_out;
for (int i = 0; i < out_rows; i++) {
for (int j = 0; j < out_cols; j++) {
index_out = i * out_cols + j;
data_out[index_out] = bias ? bias[j] : 0;
for (int k = 0; k < in_cols; k++) {
data_out[index_out] +=
data_in[i * in_cols + k] * weights[j * in_cols + k];
}
}
}
}
TEST(search_fc, normal) {
SearchFcCompute<float> search_fc_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::SearchFcParam param;
lite::Tensor X, X_gpu, W, W_gpu, b, b_gpu;
lite::Tensor Out, Out_cpu, out_ref;
std::vector<int64_t> x_shape{1, 4};
X.Resize(lite::DDim(x_shape));
std::vector<int64_t> w_shape{3, 4};
W.Resize(lite::DDim(w_shape));
std::vector<int64_t> b_shape{3};
b.Resize(lite::DDim(b_shape));
  // The output width must match the number of weight rows: W is {3, 4},
  // so Out is {1, 3}.
  std::vector<int64_t> out_shape{1, 3};
Out.Resize(lite::DDim(out_shape));
out_ref.Resize(lite::DDim(out_shape));
auto x_data = X.mutable_data<float>();
auto w_data = W.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data_ref = out_ref.mutable_data<float>();
for (int64_t i = 0; i < X.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < W.dims().production(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < b.dims().production(); i++) {
b_data[i] = static_cast<float>(i);
}
X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(x_data, X.dims());
W_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(w_data, W.dims());
b_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(b_data, b.dims());
param.X = &X_gpu;
param.W = &W_gpu;
param.b = &b_gpu;
  param.out_size = 3;
param.Out = &Out;
search_fc_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
search_fc_kernel.SetContext(std::move(ctx));
search_fc_kernel.Run();
  fc_cpu_base(&X, &W, &b, 3, &out_ref);
cudaDeviceSynchronize();
const float* out_data = Out.data<float>();
  Out_cpu.Resize(Out.dims());
  float* out_cpu_data = Out_cpu.mutable_data<float>();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * Out.numel(), IoDirection::DtoH);
for (int i = 0; i < Out.numel(); ++i) {
EXPECT_NEAR(out_cpu_data[i], out_data_ref[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sequence_topk_avg_pooling_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
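// One thread block per (sequence, channel) pair: blockIdx.x selects the
// sequence (its row/col extents come from the LoD offsets), blockIdx.y
// selects the feature map. The feature map rows for this sequence are staged
// in shared memory, then each thread independently extracts the top-k values
// of one row.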
template <typename Dtype>
__global__ void topk_avg_pooling_kernel_by_row_improve(
Dtype *output_data,
const Dtype *input,
const int *gpu_input_offset_l,
const int *gpu_input_offset_r,
const int row_max,
const int col_max,
const int topk_size,
const int *topks,
const int feat_map_num) {
  int row =
      gpu_input_offset_l[blockIdx.x + 1] - gpu_input_offset_l[blockIdx.x];
  int col =
      gpu_input_offset_r[blockIdx.x + 1] - gpu_input_offset_r[blockIdx.x];
int max_k = topks[topk_size - 1];
max_k = max_k < col ? max_k : col;
extern __shared__ Dtype smem[]; // H*W
const Dtype *fm_row_in_data = input +
blockIdx.x * row_max * feat_map_num * col_max +
blockIdx.y * row_max * col_max;
for (int i = threadIdx.x; i < row * col_max; i += blockDim.x) {
smem[i] = fm_row_in_data[i];
}
__syncthreads();
for (int idx = threadIdx.x; idx < row; idx += blockDim.x) {
Dtype *fm_row_out_data =
output_data +
(gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size +
blockIdx.y * topk_size;
Dtype *smem_start_col = smem + idx * col_max;
    int counter = max_k;
Dtype last_max_val = -20000.0;
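    // Selection-style top-k: each pass extracts the current maximum, masks
    // it out in shared memory, and adds it to every bucket whose k is at
    // least the current rank; last_max_val is reused when fewer than max_k
    // valid elements remain.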
while (counter) {
Dtype max_val = -10000.0;
int max_pos = 0;
int m = 0;
for (; m < col; m++) {
Dtype cur_data = smem_start_col[m];
if (cur_data > max_val) {
max_val = cur_data;
max_pos = m;
last_max_val = max_val;
}
}
      if (max_val < -9999.0) {  // no unmasked element was found
        max_val = last_max_val;
      }
      // Mask out the picked element with a large negative sentinel so the
      // next pass skips it.
      smem_start_col[max_pos] = -10000000.0;
int i = max_k - counter;
for (int c = 0; c < topk_size; c++) {
if (i <= topks[c] - 1) {
fm_row_out_data[c] += max_val;
}
}
counter--;
}
    // No barrier here: each thread works on its own row of smem and its own
    // output slice, and a __syncthreads() inside this divergent loop would
    // be undefined behavior.
    // compute avg
for (int i = 0; i < topk_size; i++) {
fm_row_out_data[i] = fm_row_out_data[i] / topks[i];
}
}
}
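// Run copies the top-k list and the LoD offsets of X and ROW into device
// buffers, zeroes the output, and launches one block per (sequence, channel)
// pair with enough shared memory for one feature map.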
template <typename T>
void SequenceTopkAvgPoolingCompute<T>::Run() {
auto &param = this->Param<param_t>();
auto &ctx = this->ctx_->template As<CUDAContext>();
auto cuda_stream = ctx.exec_stream();
int topk_num = param.topks.size();
lite::DDim top_ks_shape(std::vector<int64_t>{topk_num, 1, 1, 1});
_top_ks.Resize(top_ks_shape);
cudaMemcpyAsync(_top_ks.mutable_data<int>(TARGET(kCUDA)),
&param.topks[0],
sizeof(int) * topk_num,
cudaMemcpyHostToDevice,
cuda_stream);
  // LoD offsets are stored on the host as uint64_t; convert them to the int
  // buffers the kernel indexes before copying. (Pageable H2D copies are
  // staged before cudaMemcpyAsync returns, so the temporaries may go out of
  // scope safely.)
  int width_offset_len = param.X->lod()[0].size();
  lite::DDim width_offset_shape(
      std::vector<int64_t>{width_offset_len, 1, 1, 1});
  _width_offset.Resize(width_offset_shape);
  std::vector<int> width_lod_cpu(param.X->lod()[0].begin(),
                                 param.X->lod()[0].end());
  cudaMemcpyAsync(_width_offset.mutable_data<int>(TARGET(kCUDA)),
                  width_lod_cpu.data(),
                  sizeof(int) * width_offset_len,
                  cudaMemcpyHostToDevice,
                  cuda_stream);
  int height_offset_len = param.ROW->lod()[0].size();
  lite::DDim height_offset_shape(
      std::vector<int64_t>{height_offset_len, 1, 1, 1});
  _height_offset.Resize(height_offset_shape);
  std::vector<int> height_lod_cpu(param.ROW->lod()[0].begin(),
                                  param.ROW->lod()[0].end());
  cudaMemcpyAsync(_height_offset.mutable_data<int>(TARGET(kCUDA)),
                  height_lod_cpu.data(),
                  sizeof(int) * height_offset_len,
                  cudaMemcpyHostToDevice,
                  cuda_stream);
const Tensor *x_tensor = param.X;
Tensor *out_tensor = param.Out;
const T *in_data = x_tensor->data<T>();
T *out_data = out_tensor->mutable_data<T>(TARGET(kCUDA));
TargetWrapperCuda::MemsetAsync(out_tensor->mutable_data<T>(TARGET(kCUDA)),
0,
sizeof(T) * out_tensor->numel(),
cuda_stream);
auto x_dims = x_tensor->dims();
int num = x_dims[0];
int channel = x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
const int *height_offset = _height_offset.data<int>();
const int *width_offset = _width_offset.data<int>();
int feat_map_size = height * width;
dim3 blocks(num, channel);
dim3 threads(32, 1);
topk_avg_pooling_kernel_by_row_improve<
T><<<blocks, threads, feat_map_size * sizeof(T), cuda_stream>>>(
out_data,
in_data,
height_offset,
width_offset,
height,
width,
param.topks.size(),
_top_ks.data<int>(),
param.channel_num);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
sequence_topk_avg_pooling,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequenceTopkAvgPoolingCompute<float>,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("ROW",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("COLUMN",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("pos",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
class SequenceTopkAvgPoolingCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::SequenceTopkAvgPoolingParam;
void Run() override;
virtual ~SequenceTopkAvgPoolingCompute() = default;
protected:
lite::Tensor _height_offset;
lite::Tensor _width_offset;
lite::Tensor _top_ks;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle