diff --git a/lite/backends/cuda/cuda_utils.h b/lite/backends/cuda/cuda_utils.h index 13bf8190efe1592e7509039a569d31f6bddc5b66..9da70262f5b2e32ae8509d9370142b2499886bfb 100644 --- a/lite/backends/cuda/cuda_utils.h +++ b/lite/backends/cuda/cuda_utils.h @@ -56,6 +56,15 @@ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << CudnnGetErrorInfo(status); \ } +const int CUDA_NUM_THREADS = 512; +// CUDA: number of blocks for threads. +inline int CUDA_GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} +inline int CUDA_GET_BLOCKS(const int N, const int base) { + return (N + base - 1) / base; +} + namespace paddle { namespace lite { namespace cuda { diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt index ff690e52a2a399192eddec466a763149fcce71d7..1829bcf330aba31708ac97c97d093afbda197908 100644 --- a/lite/backends/cuda/math/CMakeLists.txt +++ b/lite/backends/cuda/math/CMakeLists.txt @@ -13,6 +13,7 @@ nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps}) nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps}) nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps}) +nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps}) set ( math_cuda @@ -23,6 +24,7 @@ set ( cuda_transpose cuda_elementwise cuda_gemm + cuda_batched_gemm ) set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda") diff --git a/lite/backends/cuda/math/batched_gemm.cc b/lite/backends/cuda/math/batched_gemm.cc new file mode 100644 index 0000000000000000000000000000000000000000..e81510927615daa88e7f5bef3ce7b8421d8f6539 --- /dev/null +++ b/lite/backends/cuda/math/batched_gemm.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/math/batched_gemm.h" +#include +#include "lite/core/device_info.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template <> +bool BatchedGemm::init(const bool trans_a, + const bool trans_b, + const int max_batch_size, + Context *ctx) { + if (cu_handle_ == nullptr) { + this->exe_stream_ = ctx->exec_stream(); + CUBLAS_CALL(cublasCreate(&cu_handle_)); + CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_)); + } + cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; + cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; + cudaMalloc(reinterpret_cast(&A_), + 3 * max_batch_size * sizeof(float *)); + return true; +} + +template <> +bool BatchedGemm::run(const float alpha, + const float beta, + const float *a[], + const float *b[], + float *c[], + const int m, + const int n, + const int k, + const int batch_size) { + CHECK(a != nullptr); + CHECK(b != nullptr); + CHECK(c != nullptr); + lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m; + ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k; + ldc_ = n; + m_ = m; + n_ = n; + k_ = k; + cudaMemcpyAsync(A_, + a, + batch_size * sizeof(const float *), + cudaMemcpyHostToDevice, + exe_stream_); + cudaMemcpyAsync(A_ + batch_size, + b, + batch_size * sizeof(const float *), + cudaMemcpyHostToDevice, + exe_stream_); + cudaMemcpyAsync(A_ + batch_size * 2, + c, + batch_size * sizeof(float *), + cudaMemcpyHostToDevice, + exe_stream_); + CUBLAS_CALL(cublasSgemmBatched(cu_handle_, + cu_trans_b_, + cu_trans_a_, + n_, + m_, + k_, + &alpha, + const_cast(A_ + batch_size), + ldb_, + const_cast(A_), + lda_, + &beta, + A_ + batch_size * 2, + ldc_, + batch_size)); + return true; +} + +template <> +bool BatchedGemm::run(const float alpha, + const float beta, + const float *a[], + const int m, + const int n, + const int k, + const int batch_size) { + CHECK(a != nullptr); + lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m; + ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k; + ldc_ = n; + m_ = m; + n_ = n; + k_ = k; + cudaMemcpyAsync(A_, + a, + 3 * batch_size * sizeof(const float *), + cudaMemcpyDefault, + exe_stream_); + CUBLAS_CALL(cublasSgemmBatched(cu_handle_, + cu_trans_b_, + cu_trans_a_, + n_, + m_, + k_, + &alpha, + const_cast(A_ + batch_size), + ldb_, + const_cast(A_), + lda_, + &beta, + A_ + batch_size * 2, + ldc_, + batch_size)); + return true; +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/batched_gemm.h b/lite/backends/cuda/math/batched_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2b91d3a524596bf03b4a26a81c14eddcfe64452f --- /dev/null +++ b/lite/backends/cuda/math/batched_gemm.h @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "lite/api/paddle_place.h" +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +class BatchedGemm { + public: + BatchedGemm() : cu_handle_(nullptr) {} + ~BatchedGemm() { + if (A_ != nullptr) { + cudaFree(A_); + } + } + + bool init(const bool trans_a, + const bool trans_b, + const int max_batch_size, + Context* ctx); + + bool run(const PtypeOut alpha, + const PtypeOut beta, + const PtypeIn* a[], + const PtypeIn* b[], + PtypeOut* c[], + const int m, + const int n, + const int k, + const int batch_size); + + bool run(const PtypeOut alpha, + const PtypeOut beta, + const PtypeIn* a[], + const int m, + const int n, + const int k, + const int batch_size); + + private: + cudaStream_t exe_stream_; + cublasHandle_t cu_handle_; + cublasOperation_t cu_trans_a_; + cublasOperation_t cu_trans_b_; + int m_{-1}; + int n_{-1}; + int k_{-1}; + int lda_{-1}; + int ldb_{-1}; + int ldc_{-1}; + PtypeIn** A_{nullptr}; +}; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 0bdf478241b63322df7eda41435a543bd094a5c1..a61aebb8f8e2f2ed5ceda3640c442dcc09c9a8a6 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -30,6 +30,8 @@ add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_ add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps}) add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm) +add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm) +add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm) add_kernel(var_conv_2d_compute_cuda CUDA basic SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) @@ -57,4 +59,6 @@ nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_ if(LITE_BUILD_EXTRA) nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda) + nv_test(search_aligned_mat_mul_compute_cuda_test SRCS search_aligned_mat_mul_compute_test.cc DEPS search_aligned_mat_mul_compute_cuda) + nv_test(search_seq_fc_compute_cuda_test SRCS search_seq_fc_compute_test.cc DEPS search_seq_fc_compute_cuda) endif() diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.cc b/lite/kernels/cuda/search_aligned_mat_mul_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..525765de283011535fbe154e31eb0afa2dee0daf --- /dev/null +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/search_aligned_mat_mul_compute.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda {} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(search_aligned_mat_mul, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SearchAlignedMatMulCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.h b/lite/kernels/cuda/search_aligned_mat_mul_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..b1c4552d9c43e2dcbc3bf0211f7028811410cb6c --- /dev/null +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.h @@ -0,0 +1,103 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/backends/cuda/math/batched_gemm.h" +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/types.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SearchAlignedMatMulCompute + : public KernelLite { + public: + using param_t = operators::MatMulParam; + + void PrepareForRun() override { + auto& param = this->Param(); + CHECK(ctx_) << "running context should be set first"; + auto& cuda_ctx = ctx_->template As(); + bool x_transpose = param.transpose_X; + bool y_transpose = param.transpose_Y; + int seq_num = param.X->lod()[0].size() - 1; + batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm); + CHECK( + batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx)); + A_ = static_cast(malloc(3 * seq_num * sizeof(float*))); + CHECK(A_); + } + + void Run() override { + auto& param = this->Param(); + auto x = param.X; + auto y = param.Y; + auto out = param.Out; + bool x_transpose = param.transpose_X; + bool y_transpose = param.transpose_Y; + float alpha = param.alpha; + const auto& x_dims = x->dims(); + const auto& y_dims = y->dims(); + const auto& x_lod = x->lod(); + const auto& y_lod = y->lod(); + const auto& x_lod_0 = x_lod[0]; + const auto& y_lod_0 = y_lod[0]; + int seq_num = x_lod_0.size() - 1; + int x_inner_size = x_dims[1]; + int y_inner_size = y_dims[1]; + int x_batch_size = x_lod_0[1]; + int y_batch_size = y_lod_0[1]; + int M = x_transpose ? x_inner_size : x_batch_size; + int N = y_transpose ? y_batch_size : y_inner_size; + int X_K = x_transpose ? x_batch_size : x_inner_size; + int Y_K = y_transpose ? y_inner_size : y_batch_size; + CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal"; + int K = X_K; + + auto x_data = x->data(); + auto y_data = y->data(); + auto out_data = out->mutable_data(TARGET(kCUDA)); + auto x_stride = x_batch_size * x_inner_size; + auto y_stride = y_batch_size * y_inner_size; + auto out_stride = M * N; + for (int seq = 0; seq < seq_num; seq++) { + A_[seq] = const_cast(x_data) + seq * x_stride; + A_[seq + seq_num] = const_cast(y_data) + seq * y_stride; + A_[seq + seq_num * 2] = out_data + seq * out_stride; + } + batched_gemm_impl_->run( + alpha, 0.0f, const_cast(A_), M, N, K, seq_num); + } + + ~SearchAlignedMatMulCompute() { + if (A_ != nullptr) { + free(A_); + } + } + + private: + std::unique_ptr> + batched_gemm_impl_; + float** A_{nullptr}; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc b/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..66b3478f64f1144aa09404cb943c3de49e549b0d --- /dev/null +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc @@ -0,0 +1,221 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/search_aligned_mat_mul_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void search_aligned_mat_mul_compute_ref(const operators::MatMulParam& param) { + auto x = param.X; + auto y = param.Y; + auto out = param.Out; + bool x_transpose = param.transpose_X; + bool y_transpose = param.transpose_Y; + T alpha = static_cast(param.alpha); + const auto x_dims = x->dims(); + const auto y_dims = y->dims(); + const auto& x_lod = x->lod(); + const auto& y_lod = y->lod(); + const auto& x_lod_0 = x_lod[0]; + const auto& y_lod_0 = y_lod[0]; + int seq_num = x_lod_0.size() - 1; + int x_inner_size = x_dims[1]; + int y_inner_size = y_dims[1]; + int x_batch_size = x_lod_0[1]; + int y_batch_size = y_lod_0[1]; + int M = x_transpose ? x_inner_size : x_batch_size; + int N = y_transpose ? y_batch_size : y_inner_size; + int X_K = x_transpose ? x_batch_size : x_inner_size; + int Y_K = y_transpose ? y_inner_size : y_batch_size; + CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal"; + int K = X_K; + int lda = x_transpose ? M : K; + int ldb = y_transpose ? K : N; + int ldc = N; + int x_stride = x_batch_size * x_inner_size; + int y_stride = y_batch_size * y_inner_size; + int out_stride = M * N; + auto x_data = x->data(); + auto y_data = y->data(); + auto out_data = out->mutable_data(); +#pragma omp parallel for + for (int seq = 0; seq < seq_num; seq++) { + auto a = x_data + seq * x_stride; + auto b = y_data + seq * y_stride; + auto c = out_data + seq * out_stride; + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + auto sum = static_cast(0); + for (int l = 0; l < K; l++) { + T av; + T bv; + if (x_transpose) { + av = a[l * lda + i]; + } else { + av = a[i * lda + l]; + } + if (y_transpose) { + bv = b[j * ldb + l]; + } else { + bv = b[l * ldb + j]; + } + sum += av * bv; + } + c[i * ldc + j] = alpha * sum; + } + } + } +} + +TEST(search_aligned_mat_mul_compute, normal) { + Env::Init(); + for (int seq_num : {1, 2}) { + for (int x_batch_size : {1, 3}) { + for (int x_inner_size : {1, 5}) { + for (int out_inner_size : {1, 4}) { + for (bool x_transpose : {true, false}) { + for (bool y_transpose : {true, false}) { + for (float alpha : {1., 2.}) { + // infer x_dims and y_dims + int y_batch_size; + int y_inner_size; + int out_batch_size; + if (x_transpose) { + if (y_transpose) { + y_batch_size = out_inner_size; + y_inner_size = x_batch_size; + out_batch_size = x_inner_size; + } else { + y_batch_size = x_batch_size; + y_inner_size = out_inner_size; + out_batch_size = x_inner_size; + } + } else { + if (y_transpose) { + y_batch_size = out_inner_size; + y_inner_size = x_inner_size; + out_batch_size = x_batch_size; + } else { + y_batch_size = x_inner_size; + y_inner_size = out_inner_size; + out_batch_size = x_batch_size; + } + } + std::vector x_lod_0(seq_num + 1); + std::vector y_lod_0(seq_num + 1); + std::vector out_lod_0(seq_num + 1); + x_lod_0[0] = 0; + y_lod_0[0] = 0; + out_lod_0[0] = 0; + for (int i = 0; i < seq_num; i++) { + x_lod_0[i + 1] = x_lod_0[i] + x_batch_size; + y_lod_0[i + 1] = y_lod_0[i] + y_batch_size; + out_lod_0[i + 1] = out_lod_0[i] + out_batch_size; + } + LoD x_lod; + LoD y_lod; + LoD out_lod; + x_lod.push_back(x_lod_0); + y_lod.push_back(y_lod_0); + out_lod.push_back(out_lod_0); + DDim x_dims({static_cast(x_lod_0.back()), + static_cast(x_inner_size)}); + DDim y_dims({static_cast(y_lod_0.back()), + static_cast(y_inner_size)}); + DDim out_dims({static_cast(out_lod_0.back()), + static_cast(out_inner_size)}); + // prepare input&output tensors + Tensor x_dev, x_host, y_dev, y_host, out_dev, out_host, out_ref; + x_host.Resize(x_dims); + y_host.Resize(y_dims); + out_host.Resize(out_dims); + x_dev.Resize(x_dims); + y_dev.Resize(y_dims); + out_dev.Resize(out_dims); + out_ref.Resize(out_dims); + x_host.set_lod(x_lod); + y_host.set_lod(y_lod); + out_host.set_lod(out_lod); + x_dev.set_lod(x_lod); + y_dev.set_lod(y_lod); + out_dev.set_lod(out_lod); + out_ref.set_lod(out_lod); + auto out_dev_data = out_dev.mutable_data(TARGET(kCUDA)); + auto x_host_data = x_host.mutable_data(); + auto y_host_data = y_host.mutable_data(); + auto out_host_data = out_host.mutable_data(); + auto out_ref_data = out_ref.mutable_data(); + for (int i = 0; i < x_host.dims().production(); i++) { + x_host_data[i] = i * 0.125f; + } + for (int i = 0; i < y_host.dims().production(); i++) { + y_host_data[i] = i * 0.5f; + } + x_dev.Assign(x_host_data, + x_host.dims()); + y_dev.Assign(y_host_data, + y_host.dims()); + // prepare cuda context, initialize param, and run kernel + operators::MatMulParam param; + param.X = &x_dev; + param.Y = &y_dev; + param.Out = &out_dev; + param.alpha = alpha; + param.transpose_X = x_transpose; + param.transpose_Y = y_transpose; + std::unique_ptr ctx(new KernelContext); + auto& cuda_ctx = ctx->As(); + cuda_ctx.InitOnce(); + int dev_id = TargetWrapper::GetCurDevice(); + cuda_ctx.Init(dev_id); + SearchAlignedMatMulCompute search_aligned_mat_mul; + search_aligned_mat_mul.SetParam(param); + search_aligned_mat_mul.SetContext(std::move(ctx)); + search_aligned_mat_mul.Launch(); + cudaDeviceSynchronize(); + CopySync( + out_host_data, + out_dev_data, + sizeof(float) * out_dev.dims().production(), + IoDirection::DtoH); + // run reference + param.X = &x_host; + param.Y = &y_host; + param.Out = &out_ref; + search_aligned_mat_mul_compute_ref(param); + // verify result + for (int i = 0; i < out_ref.dims().production(); i++) { + EXPECT_NEAR(out_host_data[i], out_ref_data[i], 1e-5); + } + } + } + } + } + } + } + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/search_seq_fc_compute.cu b/lite/kernels/cuda/search_seq_fc_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..e3ac75afeeee772ed7486a47dde14b7a3af4085f --- /dev/null +++ b/lite/kernels/cuda/search_seq_fc_compute.cu @@ -0,0 +1,98 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/search_seq_fc_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void add_bias(int n, + int output_size, + const dtype* bias, + dtype* dout) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int bias_index = index % output_size; + if (index < n) { + dout[index] = dout[index] + bias[bias_index]; + } +} + +void SearchSeqFcCompute::PrepareForRun() { + gemm_impl_.reset(new lite::cuda::math::Gemm); +} + +void SearchSeqFcCompute::Run() { + auto& param = this->Param(); + CHECK(ctx_) << "running context should be set first"; + auto& cuda_ctx = ctx_->template As(); + auto cuda_stream = cuda_ctx.exec_stream(); + + auto x = param.x; + auto w = param.w; + auto b = param.b; + auto out = param.out; + auto out_size = param.out_size; + const auto x_dims = x->dims(); + const auto w_dims = w->dims(); + const auto out_dims = out->dims(); + CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor."; + CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor."; + CHECK_EQ(out_dims.size(), 2) << "The Output(Out) should be 2-D tensor."; + CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]"; + CHECK_EQ(w_dims[0], out_size) << "Wrong shape: w_dims[0] != out_size"; + CHECK_EQ(out_dims[0], x_dims[0]) << "Wrong shape: out_dims[0] != x_dims[0]"; + CHECK_EQ(out_dims[1], out_size) << "Wrong shape: out_dims[1] != out_size"; + int M = x_dims[0]; + int K = x_dims[1]; + int N = w_dims[0]; + auto x_data = x->data(); + auto w_data = w->data(); + auto out_data = out->mutable_data(TARGET(kCUDA)); + + CHECK(gemm_impl_->init(false, true, M, N, K, &cuda_ctx)); + gemm_impl_->run(1.0f, 0.0f, x_data, w_data, out_data, &cuda_ctx); + + if (b != nullptr) { + auto b_dims = b->dims(); + CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor."; + CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]"; + auto b_data = b->mutable_data(); + int total_size = M * N; + add_bias<<>>(total_size, N, b_data, out_data); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(search_seq_fc, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SearchSeqFcCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("b", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/search_seq_fc_compute.h b/lite/kernels/cuda/search_seq_fc_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..dff8ba2acfbe28fc72f095294ad5a140ed66f150 --- /dev/null +++ b/lite/kernels/cuda/search_seq_fc_compute.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/backends/cuda/math/gemm.h" +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/types.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SearchSeqFcCompute : public KernelLite { + public: + using param_t = operators::SearchSeqFcParam; + + void PrepareForRun() override; + void Run() override; + virtual ~SearchSeqFcCompute() = default; + + private: + std::unique_ptr> gemm_impl_{nullptr}; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/search_seq_fc_compute_test.cc b/lite/kernels/cuda/search_seq_fc_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b9beb7b09290e81f17ff2580ff68f4592c9b132 --- /dev/null +++ b/lite/kernels/cuda/search_seq_fc_compute_test.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/search_seq_fc_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void search_seq_fc_compute_ref(const operators::SearchSeqFcParam& param) { + auto x = param.x; + auto w = param.w; + auto b = param.b; + auto out = param.out; + auto out_size = param.out_size; + const auto x_dims = x->dims(); + const auto w_dims = w->dims(); + const auto& x_lod = x->lod(); + CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor."; + CHECK(!x_lod.empty()) << "The Input(X) must hold lod info."; + const auto& x_lod_0 = x_lod[0]; + CHECK_GE(x_lod_0.size(), 2) << "The Input(X)'s lod info is corrupted."; + CHECK_EQ(x_dims[0], static_cast(x_lod_0.back())) + << "The Input(X)'s lod info mismatches the actual tensor shape."; + CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor."; + CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]"; + CHECK_EQ(w_dims[0], out_size) << "Wrong shape: w_dims[0] != out_size"; + int M = x_dims[0]; + int K = x_dims[1]; + int N = w_dims[0]; + auto x_data = x->data(); + auto w_data = w->data(); + auto out_data = out->mutable_data(); + +#pragma omp parallel for + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + auto sum = static_cast(0); + for (int l = 0; l < K; l++) { + T xv = x_data[i * K + l]; + T wv = w_data[j * K + l]; + sum += xv * wv; + } + out_data[i * N + j] = sum; + } + } + + if (b != nullptr) { + auto b_dims = b->dims(); + CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor."; + CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]"; + auto b_data = b->data(); + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + out_data[i * N + j] += b_data[j]; + } + } + } +} + +TEST(search_seq_fc_compute, normal) { + Env::Init(); + for (auto x_lod_0 : {std::vector({0, 1, 3}), + std::vector({0, 3, 4, 5})}) { + for (auto feature_size : {2, 9}) { + for (auto out_size : {3, 5}) { + for (auto has_bias : {true, false}) { + // infer x_dims, w_dims, b_dims and out_dims + DDim x_dims({static_cast(x_lod_0.back()), feature_size}); + DDim w_dims({out_size, feature_size}); + DDim b_dims({has_bias ? out_size : 0}); + DDim out_dims({static_cast(x_lod_0.back()), out_size}); + LoD x_lod; + x_lod.push_back(x_lod_0); + LoD out_lod; + out_lod.push_back(x_lod_0); + // prepare input&output tensors + Tensor x_dev, x_host, w_dev, w_host, b_dev, b_host, out_dev, out_host, + out_ref; + x_host.Resize(x_dims); + w_host.Resize(w_dims); + b_host.Resize(b_dims); + out_host.Resize(out_dims); + x_dev.Resize(x_dims); + w_dev.Resize(w_dims); + b_dev.Resize(b_dims); + out_dev.Resize(out_dims); + out_ref.Resize(out_dims); + x_host.set_lod(x_lod); + out_host.set_lod(out_lod); + x_dev.set_lod(x_lod); + out_dev.set_lod(out_lod); + out_ref.set_lod(out_lod); + auto out_dev_data = out_dev.mutable_data(TARGET(kCUDA)); + auto x_host_data = x_host.mutable_data(); + auto w_host_data = w_host.mutable_data(); + auto out_host_data = out_host.mutable_data(); + auto out_ref_data = out_ref.mutable_data(); + for (int i = 0; i < x_host.dims().production(); i++) { + x_host_data[i] = i * 0.125f; + } + for (int i = 0; i < w_host.dims().production(); i++) { + w_host_data[i] = i * 0.5f; + } + x_dev.Assign(x_host_data, + x_host.dims()); + w_dev.Assign(w_host_data, + w_host.dims()); + // prepare cuda context, initialize param, and run kernel + operators::SearchSeqFcParam param; + param.x = &x_dev; + param.w = &w_dev; + param.out = &out_dev; + param.out_size = out_size; + if (has_bias) { + auto b_host_data = b_host.mutable_data(); + for (int i = 0; i < b_host.dims().production(); i++) { + b_host_data[i] = i * 0.5f; + } + b_dev.Assign(b_host_data, + b_host.dims()); + param.b = &b_dev; + } + std::unique_ptr ctx(new KernelContext); + auto& cuda_ctx = ctx->As(); + cuda_ctx.InitOnce(); + int dev_id = TargetWrapper::GetCurDevice(); + cuda_ctx.Init(dev_id); + SearchSeqFcCompute search_seq_fc; + search_seq_fc.SetParam(param); + search_seq_fc.SetContext(std::move(ctx)); + search_seq_fc.Launch(); + cudaDeviceSynchronize(); + CopySync(out_host_data, + out_dev_data, + sizeof(float) * out_dev.dims().production(), + IoDirection::DtoH); + // run reference + param.x = &x_host; + param.w = &w_host; + param.out = &out_ref; + if (has_bias) { + param.b = &b_host; + } + search_seq_fc_compute_ref(param); + // verify result + for (int i = 0; i < out_ref.dims().production(); i++) { + EXPECT_NEAR(out_host_data[i], out_ref_data[i], 1e-5); + } + } + } + } + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle