未验证 提交 8373aec5 编写于 作者: H hong19860320 提交者: GitHub

[LITE][CUDA] Add CUDA kernel for search_aligned_mat_mul and search_seq_fc Op (#2449)

上级 f6aba39d
......@@ -56,6 +56,15 @@
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << CudnnGetErrorInfo(status); \
}
const int CUDA_NUM_THREADS = 512;
// CUDA: number of blocks for threads.
inline int CUDA_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
inline int CUDA_GET_BLOCKS(const int N, const int base) {
return (N + base - 1) / base;
}
namespace paddle {
namespace lite {
namespace cuda {
......
......@@ -13,6 +13,7 @@ nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
set (
math_cuda
......@@ -23,6 +24,7 @@ set (
cuda_transpose
cuda_elementwise
cuda_gemm
cuda_batched_gemm
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/batched_gemm.h"
#include <iostream>
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <>
bool BatchedGemm<float, float>::init(const bool trans_a,
const bool trans_b,
const int max_batch_size,
Context<TARGET(kCUDA)> *ctx) {
if (cu_handle_ == nullptr) {
this->exe_stream_ = ctx->exec_stream();
CUBLAS_CALL(cublasCreate(&cu_handle_));
CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));
}
cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
cudaMalloc(reinterpret_cast<void **>(&A_),
3 * max_batch_size * sizeof(float *));
return true;
}
template <>
bool BatchedGemm<float, float>::run(const float alpha,
const float beta,
const float *a[],
const float *b[],
float *c[],
const int m,
const int n,
const int k,
const int batch_size) {
CHECK(a != nullptr);
CHECK(b != nullptr);
CHECK(c != nullptr);
lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
ldc_ = n;
m_ = m;
n_ = n;
k_ = k;
cudaMemcpyAsync(A_,
a,
batch_size * sizeof(const float *),
cudaMemcpyHostToDevice,
exe_stream_);
cudaMemcpyAsync(A_ + batch_size,
b,
batch_size * sizeof(const float *),
cudaMemcpyHostToDevice,
exe_stream_);
cudaMemcpyAsync(A_ + batch_size * 2,
c,
batch_size * sizeof(float *),
cudaMemcpyHostToDevice,
exe_stream_);
CUBLAS_CALL(cublasSgemmBatched(cu_handle_,
cu_trans_b_,
cu_trans_a_,
n_,
m_,
k_,
&alpha,
const_cast<const float **>(A_ + batch_size),
ldb_,
const_cast<const float **>(A_),
lda_,
&beta,
A_ + batch_size * 2,
ldc_,
batch_size));
return true;
}
template <>
bool BatchedGemm<float, float>::run(const float alpha,
const float beta,
const float *a[],
const int m,
const int n,
const int k,
const int batch_size) {
CHECK(a != nullptr);
lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
ldc_ = n;
m_ = m;
n_ = n;
k_ = k;
cudaMemcpyAsync(A_,
a,
3 * batch_size * sizeof(const float *),
cudaMemcpyDefault,
exe_stream_);
CUBLAS_CALL(cublasSgemmBatched(cu_handle_,
cu_trans_b_,
cu_trans_a_,
n_,
m_,
k_,
&alpha,
const_cast<const float **>(A_ + batch_size),
ldb_,
const_cast<const float **>(A_),
lda_,
&beta,
A_ + batch_size * 2,
ldc_,
batch_size));
return true;
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename PtypeIn, typename PtypeOut>
class BatchedGemm {
public:
BatchedGemm() : cu_handle_(nullptr) {}
~BatchedGemm() {
if (A_ != nullptr) {
cudaFree(A_);
}
}
bool init(const bool trans_a,
const bool trans_b,
const int max_batch_size,
Context<TARGET(kCUDA)>* ctx);
bool run(const PtypeOut alpha,
const PtypeOut beta,
const PtypeIn* a[],
const PtypeIn* b[],
PtypeOut* c[],
const int m,
const int n,
const int k,
const int batch_size);
bool run(const PtypeOut alpha,
const PtypeOut beta,
const PtypeIn* a[],
const int m,
const int n,
const int k,
const int batch_size);
private:
cudaStream_t exe_stream_;
cublasHandle_t cu_handle_;
cublasOperation_t cu_trans_a_;
cublasOperation_t cu_trans_b_;
int m_{-1};
int n_{-1};
int k_{-1};
int lda_{-1};
int ldb_{-1};
int ldc_{-1};
PtypeIn** A_{nullptr};
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -30,6 +30,8 @@ add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm)
add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(var_conv_2d_compute_cuda CUDA basic SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
......@@ -57,4 +59,6 @@ nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_
if(LITE_BUILD_EXTRA)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
nv_test(search_aligned_mat_mul_compute_cuda_test SRCS search_aligned_mat_mul_compute_test.cc DEPS search_aligned_mat_mul_compute_cuda)
nv_test(search_seq_fc_compute_cuda_test SRCS search_seq_fc_compute_test.cc DEPS search_seq_fc_compute_cuda)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_aligned_mat_mul_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_aligned_mat_mul,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SearchAlignedMatMulCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/batched_gemm.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/types.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SearchAlignedMatMulCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::MatMulParam;
void PrepareForRun() override {
auto& param = this->Param<param_t>();
CHECK(ctx_) << "running context should be set first";
auto& cuda_ctx = ctx_->template As<CUDAContext>();
bool x_transpose = param.transpose_X;
bool y_transpose = param.transpose_Y;
int seq_num = param.X->lod()[0].size() - 1;
batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm<float, float>);
CHECK(
batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx));
A_ = static_cast<float**>(malloc(3 * seq_num * sizeof(float*)));
CHECK(A_);
}
void Run() override {
auto& param = this->Param<param_t>();
auto x = param.X;
auto y = param.Y;
auto out = param.Out;
bool x_transpose = param.transpose_X;
bool y_transpose = param.transpose_Y;
float alpha = param.alpha;
const auto& x_dims = x->dims();
const auto& y_dims = y->dims();
const auto& x_lod = x->lod();
const auto& y_lod = y->lod();
const auto& x_lod_0 = x_lod[0];
const auto& y_lod_0 = y_lod[0];
int seq_num = x_lod_0.size() - 1;
int x_inner_size = x_dims[1];
int y_inner_size = y_dims[1];
int x_batch_size = x_lod_0[1];
int y_batch_size = y_lod_0[1];
int M = x_transpose ? x_inner_size : x_batch_size;
int N = y_transpose ? y_batch_size : y_inner_size;
int X_K = x_transpose ? x_batch_size : x_inner_size;
int Y_K = y_transpose ? y_inner_size : y_batch_size;
CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal";
int K = X_K;
auto x_data = x->data<float>();
auto y_data = y->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
auto x_stride = x_batch_size * x_inner_size;
auto y_stride = y_batch_size * y_inner_size;
auto out_stride = M * N;
for (int seq = 0; seq < seq_num; seq++) {
A_[seq] = const_cast<float*>(x_data) + seq * x_stride;
A_[seq + seq_num] = const_cast<float*>(y_data) + seq * y_stride;
A_[seq + seq_num * 2] = out_data + seq * out_stride;
}
batched_gemm_impl_->run(
alpha, 0.0f, const_cast<const float**>(A_), M, N, K, seq_num);
}
~SearchAlignedMatMulCompute() {
if (A_ != nullptr) {
free(A_);
}
}
private:
std::unique_ptr<lite::cuda::math::BatchedGemm<float, float>>
batched_gemm_impl_;
float** A_{nullptr};
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_aligned_mat_mul_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
void search_aligned_mat_mul_compute_ref(const operators::MatMulParam& param) {
auto x = param.X;
auto y = param.Y;
auto out = param.Out;
bool x_transpose = param.transpose_X;
bool y_transpose = param.transpose_Y;
T alpha = static_cast<T>(param.alpha);
const auto x_dims = x->dims();
const auto y_dims = y->dims();
const auto& x_lod = x->lod();
const auto& y_lod = y->lod();
const auto& x_lod_0 = x_lod[0];
const auto& y_lod_0 = y_lod[0];
int seq_num = x_lod_0.size() - 1;
int x_inner_size = x_dims[1];
int y_inner_size = y_dims[1];
int x_batch_size = x_lod_0[1];
int y_batch_size = y_lod_0[1];
int M = x_transpose ? x_inner_size : x_batch_size;
int N = y_transpose ? y_batch_size : y_inner_size;
int X_K = x_transpose ? x_batch_size : x_inner_size;
int Y_K = y_transpose ? y_inner_size : y_batch_size;
CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal";
int K = X_K;
int lda = x_transpose ? M : K;
int ldb = y_transpose ? K : N;
int ldc = N;
int x_stride = x_batch_size * x_inner_size;
int y_stride = y_batch_size * y_inner_size;
int out_stride = M * N;
auto x_data = x->data<T>();
auto y_data = y->data<T>();
auto out_data = out->mutable_data<T>();
#pragma omp parallel for
for (int seq = 0; seq < seq_num; seq++) {
auto a = x_data + seq * x_stride;
auto b = y_data + seq * y_stride;
auto c = out_data + seq * out_stride;
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
auto sum = static_cast<T>(0);
for (int l = 0; l < K; l++) {
T av;
T bv;
if (x_transpose) {
av = a[l * lda + i];
} else {
av = a[i * lda + l];
}
if (y_transpose) {
bv = b[j * ldb + l];
} else {
bv = b[l * ldb + j];
}
sum += av * bv;
}
c[i * ldc + j] = alpha * sum;
}
}
}
}
TEST(search_aligned_mat_mul_compute, normal) {
Env<TargetType::kCUDA>::Init();
for (int seq_num : {1, 2}) {
for (int x_batch_size : {1, 3}) {
for (int x_inner_size : {1, 5}) {
for (int out_inner_size : {1, 4}) {
for (bool x_transpose : {true, false}) {
for (bool y_transpose : {true, false}) {
for (float alpha : {1., 2.}) {
// infer x_dims and y_dims
int y_batch_size;
int y_inner_size;
int out_batch_size;
if (x_transpose) {
if (y_transpose) {
y_batch_size = out_inner_size;
y_inner_size = x_batch_size;
out_batch_size = x_inner_size;
} else {
y_batch_size = x_batch_size;
y_inner_size = out_inner_size;
out_batch_size = x_inner_size;
}
} else {
if (y_transpose) {
y_batch_size = out_inner_size;
y_inner_size = x_inner_size;
out_batch_size = x_batch_size;
} else {
y_batch_size = x_inner_size;
y_inner_size = out_inner_size;
out_batch_size = x_batch_size;
}
}
std::vector<uint64_t> x_lod_0(seq_num + 1);
std::vector<uint64_t> y_lod_0(seq_num + 1);
std::vector<uint64_t> out_lod_0(seq_num + 1);
x_lod_0[0] = 0;
y_lod_0[0] = 0;
out_lod_0[0] = 0;
for (int i = 0; i < seq_num; i++) {
x_lod_0[i + 1] = x_lod_0[i] + x_batch_size;
y_lod_0[i + 1] = y_lod_0[i] + y_batch_size;
out_lod_0[i + 1] = out_lod_0[i] + out_batch_size;
}
LoD x_lod;
LoD y_lod;
LoD out_lod;
x_lod.push_back(x_lod_0);
y_lod.push_back(y_lod_0);
out_lod.push_back(out_lod_0);
DDim x_dims({static_cast<int64_t>(x_lod_0.back()),
static_cast<int64_t>(x_inner_size)});
DDim y_dims({static_cast<int64_t>(y_lod_0.back()),
static_cast<int64_t>(y_inner_size)});
DDim out_dims({static_cast<int64_t>(out_lod_0.back()),
static_cast<int64_t>(out_inner_size)});
// prepare input&output tensors
Tensor x_dev, x_host, y_dev, y_host, out_dev, out_host, out_ref;
x_host.Resize(x_dims);
y_host.Resize(y_dims);
out_host.Resize(out_dims);
x_dev.Resize(x_dims);
y_dev.Resize(y_dims);
out_dev.Resize(out_dims);
out_ref.Resize(out_dims);
x_host.set_lod(x_lod);
y_host.set_lod(y_lod);
out_host.set_lod(out_lod);
x_dev.set_lod(x_lod);
y_dev.set_lod(y_lod);
out_dev.set_lod(out_lod);
out_ref.set_lod(out_lod);
auto out_dev_data = out_dev.mutable_data<float>(TARGET(kCUDA));
auto x_host_data = x_host.mutable_data<float>();
auto y_host_data = y_host.mutable_data<float>();
auto out_host_data = out_host.mutable_data<float>();
auto out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_host.dims().production(); i++) {
x_host_data[i] = i * 0.125f;
}
for (int i = 0; i < y_host.dims().production(); i++) {
y_host_data[i] = i * 0.5f;
}
x_dev.Assign<float, lite::DDim, TARGET(kCUDA)>(x_host_data,
x_host.dims());
y_dev.Assign<float, lite::DDim, TARGET(kCUDA)>(y_host_data,
y_host.dims());
// prepare cuda context, initialize param, and run kernel
operators::MatMulParam param;
param.X = &x_dev;
param.Y = &y_dev;
param.Out = &out_dev;
param.alpha = alpha;
param.transpose_X = x_transpose;
param.transpose_Y = y_transpose;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& cuda_ctx = ctx->As<CUDAContext>();
cuda_ctx.InitOnce();
int dev_id = TargetWrapper<TargetType::kCUDA>::GetCurDevice();
cuda_ctx.Init(dev_id);
SearchAlignedMatMulCompute search_aligned_mat_mul;
search_aligned_mat_mul.SetParam(param);
search_aligned_mat_mul.SetContext(std::move(ctx));
search_aligned_mat_mul.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_host_data,
out_dev_data,
sizeof(float) * out_dev.dims().production(),
IoDirection::DtoH);
// run reference
param.X = &x_host;
param.Y = &y_host;
param.Out = &out_ref;
search_aligned_mat_mul_compute_ref<float>(param);
// verify result
for (int i = 0; i < out_ref.dims().production(); i++) {
EXPECT_NEAR(out_host_data[i], out_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_seq_fc_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename dtype>
__global__ void add_bias(int n,
int output_size,
const dtype* bias,
dtype* dout) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int bias_index = index % output_size;
if (index < n) {
dout[index] = dout[index] + bias[bias_index];
}
}
void SearchSeqFcCompute::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
}
void SearchSeqFcCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(ctx_) << "running context should be set first";
auto& cuda_ctx = ctx_->template As<CUDAContext>();
auto cuda_stream = cuda_ctx.exec_stream();
auto x = param.x;
auto w = param.w;
auto b = param.b;
auto out = param.out;
auto out_size = param.out_size;
const auto x_dims = x->dims();
const auto w_dims = w->dims();
const auto out_dims = out->dims();
CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor.";
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor.";
CHECK_EQ(out_dims.size(), 2) << "The Output(Out) should be 2-D tensor.";
CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]";
CHECK_EQ(w_dims[0], out_size) << "Wrong shape: w_dims[0] != out_size";
CHECK_EQ(out_dims[0], x_dims[0]) << "Wrong shape: out_dims[0] != x_dims[0]";
CHECK_EQ(out_dims[1], out_size) << "Wrong shape: out_dims[1] != out_size";
int M = x_dims[0];
int K = x_dims[1];
int N = w_dims[0];
auto x_data = x->data<float>();
auto w_data = w->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
CHECK(gemm_impl_->init(false, true, M, N, K, &cuda_ctx));
gemm_impl_->run(1.0f, 0.0f, x_data, w_data, out_data, &cuda_ctx);
if (b != nullptr) {
auto b_dims = b->dims();
CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor.";
CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]";
auto b_data = b->mutable_data<float>();
int total_size = M * N;
add_bias<float><<<CUDA_GET_BLOCKS(total_size),
CUDA_NUM_THREADS,
0,
cuda_stream>>>(total_size, N, b_data, out_data);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_seq_fc,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SearchSeqFcCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/types.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SearchSeqFcCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SearchSeqFcParam;
void PrepareForRun() override;
void Run() override;
virtual ~SearchSeqFcCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_{nullptr};
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_seq_fc_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
void search_seq_fc_compute_ref(const operators::SearchSeqFcParam& param) {
auto x = param.x;
auto w = param.w;
auto b = param.b;
auto out = param.out;
auto out_size = param.out_size;
const auto x_dims = x->dims();
const auto w_dims = w->dims();
const auto& x_lod = x->lod();
CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor.";
CHECK(!x_lod.empty()) << "The Input(X) must hold lod info.";
const auto& x_lod_0 = x_lod[0];
CHECK_GE(x_lod_0.size(), 2) << "The Input(X)'s lod info is corrupted.";
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor.";
CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]";
CHECK_EQ(w_dims[0], out_size) << "Wrong shape: w_dims[0] != out_size";
int M = x_dims[0];
int K = x_dims[1];
int N = w_dims[0];
auto x_data = x->data<T>();
auto w_data = w->data<T>();
auto out_data = out->mutable_data<T>();
#pragma omp parallel for
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
auto sum = static_cast<T>(0);
for (int l = 0; l < K; l++) {
T xv = x_data[i * K + l];
T wv = w_data[j * K + l];
sum += xv * wv;
}
out_data[i * N + j] = sum;
}
}
if (b != nullptr) {
auto b_dims = b->dims();
CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor.";
CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]";
auto b_data = b->data<T>();
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
out_data[i * N + j] += b_data[j];
}
}
}
}
TEST(search_seq_fc_compute, normal) {
Env<TargetType::kCUDA>::Init();
for (auto x_lod_0 : {std::vector<uint64_t>({0, 1, 3}),
std::vector<uint64_t>({0, 3, 4, 5})}) {
for (auto feature_size : {2, 9}) {
for (auto out_size : {3, 5}) {
for (auto has_bias : {true, false}) {
// infer x_dims, w_dims, b_dims and out_dims
DDim x_dims({static_cast<int64_t>(x_lod_0.back()), feature_size});
DDim w_dims({out_size, feature_size});
DDim b_dims({has_bias ? out_size : 0});
DDim out_dims({static_cast<int64_t>(x_lod_0.back()), out_size});
LoD x_lod;
x_lod.push_back(x_lod_0);
LoD out_lod;
out_lod.push_back(x_lod_0);
// prepare input&output tensors
Tensor x_dev, x_host, w_dev, w_host, b_dev, b_host, out_dev, out_host,
out_ref;
x_host.Resize(x_dims);
w_host.Resize(w_dims);
b_host.Resize(b_dims);
out_host.Resize(out_dims);
x_dev.Resize(x_dims);
w_dev.Resize(w_dims);
b_dev.Resize(b_dims);
out_dev.Resize(out_dims);
out_ref.Resize(out_dims);
x_host.set_lod(x_lod);
out_host.set_lod(out_lod);
x_dev.set_lod(x_lod);
out_dev.set_lod(out_lod);
out_ref.set_lod(out_lod);
auto out_dev_data = out_dev.mutable_data<float>(TARGET(kCUDA));
auto x_host_data = x_host.mutable_data<float>();
auto w_host_data = w_host.mutable_data<float>();
auto out_host_data = out_host.mutable_data<float>();
auto out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_host.dims().production(); i++) {
x_host_data[i] = i * 0.125f;
}
for (int i = 0; i < w_host.dims().production(); i++) {
w_host_data[i] = i * 0.5f;
}
x_dev.Assign<float, lite::DDim, TARGET(kCUDA)>(x_host_data,
x_host.dims());
w_dev.Assign<float, lite::DDim, TARGET(kCUDA)>(w_host_data,
w_host.dims());
// prepare cuda context, initialize param, and run kernel
operators::SearchSeqFcParam param;
param.x = &x_dev;
param.w = &w_dev;
param.out = &out_dev;
param.out_size = out_size;
if (has_bias) {
auto b_host_data = b_host.mutable_data<float>();
for (int i = 0; i < b_host.dims().production(); i++) {
b_host_data[i] = i * 0.5f;
}
b_dev.Assign<float, lite::DDim, TARGET(kCUDA)>(b_host_data,
b_host.dims());
param.b = &b_dev;
}
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& cuda_ctx = ctx->As<CUDAContext>();
cuda_ctx.InitOnce();
int dev_id = TargetWrapper<TargetType::kCUDA>::GetCurDevice();
cuda_ctx.Init(dev_id);
SearchSeqFcCompute search_seq_fc;
search_seq_fc.SetParam(param);
search_seq_fc.SetContext(std::move(ctx));
search_seq_fc.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(out_host_data,
out_dev_data,
sizeof(float) * out_dev.dims().production(),
IoDirection::DtoH);
// run reference
param.x = &x_host;
param.w = &w_host;
param.out = &out_ref;
if (has_bias) {
param.b = &b_host;
}
search_seq_fc_compute_ref<float>(param);
// verify result
for (int i = 0; i < out_ref.dims().production(); i++) {
EXPECT_NEAR(out_host_data[i], out_ref_data[i], 1e-5);
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册