From e1b67433a335f44adbd525b5265c4e7bd36b6f54 Mon Sep 17 00:00:00 2001
From: juncaipeng <52520497+juncaipeng@users.noreply.github.com>
Date: Wed, 20 Nov 2019 21:54:27 +0800
Subject: [PATCH] fix x86 search_grnn, add cuda search_grnn and unit test
 (#2448)

* fix x86 search_grnn and add unit test

* add cuda search_grnn and unit test
---
 lite/kernels/cuda/CMakeLists.txt              |  11 +-
 lite/kernels/cuda/search_grnn_compute.cu      | 351 ++++++++++++++++++
 lite/kernels/cuda/search_grnn_compute.h       |  46 +++
 lite/kernels/cuda/search_grnn_compute_test.cc | 103 +++++
 lite/kernels/x86/CMakeLists.txt               |   6 +-
 lite/kernels/x86/search_grnn_compute.h        |   1 +
 lite/kernels/x86/search_grnn_compute_test.cc  | 100 +++++
 7 files changed, 611 insertions(+), 7 deletions(-)
 create mode 100644 lite/kernels/cuda/search_grnn_compute.cu
 create mode 100644 lite/kernels/cuda/search_grnn_compute.h
 create mode 100644 lite/kernels/cuda/search_grnn_compute_test.cc
 create mode 100644 lite/kernels/x86/search_grnn_compute_test.cc

diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index ac2e62f8cf..2b7cd648c9 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -23,13 +23,14 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
 add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
-add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
 add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps})
-add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
+add_kernel(match_matrix_tensor_compute_cuda CUDA extra SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
 add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm)
 add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
 add_kernel(var_conv_2d_compute_cuda CUDA basic SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
@@ -49,15 +50,17 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
 nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
 nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
-nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
 nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
 nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
 nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
 nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
-nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
 nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
 if(LITE_BUILD_EXTRA)
+  nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
+  nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
+  nv_test(search_grnn_compute_cuda_test SRCS search_grnn_compute_test.cc DEPS search_grnn_compute_cuda)
+  nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling)
   nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
   nv_test(search_aligned_mat_mul_compute_cuda_test SRCS search_aligned_mat_mul_compute_test.cc DEPS search_aligned_mat_mul_compute_cuda)
   nv_test(search_seq_fc_compute_cuda_test SRCS search_seq_fc_compute_test.cc DEPS search_seq_fc_compute_cuda)
diff --git a/lite/kernels/cuda/search_grnn_compute.cu b/lite/kernels/cuda/search_grnn_compute.cu
new file mode 100644
index 0000000000..468b66e568
--- /dev/null
+++ b/lite/kernels/cuda/search_grnn_compute.cu
@@ -0,0 +1,351 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <vector>
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/search_grnn_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+using Tensor = lite::Tensor;
+
+template <typename T>
+T sigmoid(T z) {
+  return 1 / (1 + std::exp(-z));
+}
+
+template <typename T>
+__global__ void PreComputeKernel(
+    const int num, const T* w_x_e, const T* wz_x_e, T* tilde, T* z, T* hidden) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    tilde[index] = std::tanh(w_x_e[index]);
+    z[index] = 1 / (1 + std::exp(-wz_x_e[index]));
+    hidden[index] = (1. - z[index]) * tilde[index];
+  }
+}
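+
+// A sketch of the GRU recurrence that the two kernels here implement, with
+// gate names read off the buffers below (the notation is ours, not from the
+// original op definition):
+//   r_t  = sigmoid(Wr x_t + Ur h_{t-1})
+//   z_t  = sigmoid(Wz x_t + Uz h_{t-1})
+//   h~_t = tanh(W x_t + r_t .* (U h_{t-1}))
+//   h_t  = z_t .* h_{t-1} + (1 - z_t) .* h~_t
+// PreComputeKernel covers the first time step, where h_{t-1} is zero: r_t
+// drops out and h_0 = (1 - z_0) .* h~_0. PostComputeKernel covers all later
+// steps.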
+
+template <typename T>
+__global__ void PostComputeKernel(const int start,
+                                  const int end,
+                                  const int cap_h,
+                                  const int w_tm1,
+                                  const T* wr_x_e,
+                                  const T* ur_x_h,
+                                  const T* wz_x_e,
+                                  const T* uz_x_h,
+                                  const T* w_x_e,
+                                  const T* u_x_h,
+                                  T* r,
+                                  T* z,
+                                  T* tilde,
+                                  T* hidden) {
+  int j = start + blockIdx.x * blockDim.x + threadIdx.x;
+  if (j < end) {
+    r[j] = 1 / (1 + std::exp(-(wr_x_e[j] + ur_x_h[j])));
+    z[j] = 1 / (1 + std::exp(-(wz_x_e[j] + uz_x_h[j])));
+    tilde[j] = std::tanh(w_x_e[j] + r[j] * u_x_h[j]);
+    hidden[j] = z[j] * hidden[j - cap_h * w_tm1] + (1.0 - z[j]) * tilde[j];
+  }
+}
+
+void SearchGrnnCompute::PrepareForRun() {
+  gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
+}
+
+void SearchGrnnCompute::PrepareLayout(const Tensor* input_blob) {
+  auto& param = this->Param<param_t>();
+  auto& context = this->ctx_->template As<CUDAContext>();
+  auto cuda_stream = context.exec_stream();
+
+  auto* _input = input_blob;
+  int dim0 = _input->dims()[0];
+  int dim1 = 1;
+  if (_input->dims().size() > 1) {
+    dim1 = _input->dims()[1];
+  }
+  int batch = _input->lod()[0].size() - 1;
+  auto& offset = _input->lod()[0];
+
+  idx_sorted_by_width_cpu = std::make_shared<Tensor>();
+  idx_sorted_by_width_cpu->Resize({batch});
+  int* idx_sorted_by_width_cpu_data =
+      idx_sorted_by_width_cpu->mutable_data<int>();
+
+  Tensor _width;
+  _width.Resize({batch});
+  int* width_data = _width.mutable_data<int>();
+  // sort sequences by width (descending) and find the largest width in the
+  // batch
+  for (int i = 0; i < batch; i++) {
+    width_data[i] = offset[i + 1] - offset[i];
+    idx_sorted_by_width_cpu_data[i] = i;
+  }
+  std::sort(idx_sorted_by_width_cpu_data,
+            idx_sorted_by_width_cpu_data + batch,
+            [&_width](int a, int b) {
+              return _width.data<int>()[a] > _width.data<int>()[b];
+            });
+  int max_width = width_data[idx_sorted_by_width_cpu_data[0]];
+
+  // start of reorganizing the input
+  std::vector<uint64_t> new_offset;
+  new_offset.resize(max_width + 1);
+  new_offset[0] = 0;
+  int j = batch - 1;
+  int last_width = 0;
+  int sub_row = 0;
+  int sub_col = 0;
+
+  for (int i = 1; i <= max_width;) {
+    for (int k = j; k >= 0; --k) {
+      if (width_data[idx_sorted_by_width_cpu_data[k]] > last_width) {
+        sub_row = width_data[idx_sorted_by_width_cpu_data[k]] - last_width;
+        sub_col = k + 1;
+        for (int s = 0; s < sub_row; s++) {
+          new_offset[i] = new_offset[i - 1] + sub_col;
+          i++;
+        }
+        // move on
+        last_width = width_data[idx_sorted_by_width_cpu_data[k]];
+        j = k - 1;
+        break;
+      }
+    }
+  }
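+
+  // Worked example (ours, not from the original patch): with lod offsets
+  // {0, 3, 5, 6} the sequence widths are {3, 2, 1}, already in descending
+  // order, and max_width = 3. The loop above then produces
+  // new_offset = {0, 3, 5, 6}: time step 0 holds one token from each of the
+  // 3 sequences, step 1 holds 2 tokens, and step 2 holds 1, so every step
+  // is a contiguous sub-batch that can be fed to a single GEMM.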
+
+  // copying to the reorganized buffer
+  auto* _layout_input = new Tensor();
+  auto* _layout_input_gpu = param.layout_input;
+  if (_input->dims().size() == 1) {
+    // _layout_input.reshape_batch_sequence({dim0}, new_offset);
+    LOG(FATAL) << "_input->dims().size() = 1, error.";
+  } else {
+    // _layout_input.reshape_batch_sequence({dim0, dim1}, new_offset);
+    LoD new_lod;
+    new_lod.push_back(new_offset);
+    _layout_input->set_lod(new_lod);
+    _layout_input->Resize({dim0, dim1});
+    _layout_input_gpu->set_lod(new_lod);
+    _layout_input_gpu->Resize({dim0, dim1});
+  }
+
+  auto* new_emb = _layout_input->mutable_data<float>();
+  auto* input_cpu = new Tensor();
+  input_cpu->Resize(_input->dims());
+  auto* input_cpu_data = input_cpu->mutable_data<float>();
+  TargetW::MemcpyAsync(input_cpu_data,
+                       _input->data<float>(),
+                       _input->numel() * sizeof(float),
+                       IoDirection::DtoH,
+                       cuda_stream);
+  for (int i = 0; i < max_width; i++) {
+    int w = new_offset[i + 1] - new_offset[i];
+    auto* emb_start = new_emb + dim1 * new_offset[i];
+    for (int j = 0; j < w; ++j) {
+      memcpy(emb_start + dim1 * j,
+             input_cpu_data + dim1 * offset[idx_sorted_by_width_cpu_data[j]] +
+                 dim1 * i,
+             dim1 * sizeof(float));
+    }
+  }
+
+  auto* _layout_input_gpu_data =
+      _layout_input_gpu->mutable_data<float>(TARGET(kCUDA));
+  TargetW::MemcpyAsync(_layout_input_gpu_data,
+                       new_emb,
+                       _layout_input->numel() * sizeof(float),
+                       IoDirection::HtoD,
+                       cuda_stream);
+  delete _layout_input;
+  delete input_cpu;
+}
+
+void SearchGrnnCompute::CopyBack(float* from, float* to, int step) {
+  auto& param = this->Param<param_t>();
+  auto& context = this->ctx_->template As<CUDAContext>();
+  auto stream = context.exec_stream();
+  auto* _input = param.x;
+  auto* _layout_input = param.layout_input;
+
+  const auto& offset = _input->lod()[0];
+  const auto& new_offset = _layout_input->lod()[0];
+  const auto* idx_sorted_by_width_cpu_data =
+      idx_sorted_by_width_cpu->data<int>();
+  for (size_t i = 0; i < _layout_input->lod()[0].size() - 1; ++i) {
+    int w = new_offset[i + 1] - new_offset[i];
+    for (int j = 0; j < w; j++) {
+      TargetW::MemcpyAsync(
+          to + step * (offset[idx_sorted_by_width_cpu_data[j]] + i),
+          from + (new_offset[i] + j) * step,
+          step * sizeof(float),
+          IoDirection::DtoD,
+          stream);
+    }
+  }
+}
+
+void SearchGrnnCompute::Run() {
+  CHECK(ctx_) << "running context should be set first";
+  auto& param = this->Param<param_t>();
+  auto& context = this->ctx_->template As<CUDAContext>();
+  auto stream = context.exec_stream();
+
+  auto* bottom = param.x;
+  auto* wi = param.wi;
+  auto* wh = param.wh;
+  auto* top = param.out;
+  auto* _buffer = param.tmp_buffer;
+  int _cap_h = param.num_hidden;
+  int _cap_e = param.num_input;
+
+  int _cap_l = bottom->dims()[0];
+  int batch = bottom->lod()[0].size() - 1;
+
+  const auto& offset = bottom->lod()[0];
+  LoD top_lod;
+  top_lod.push_back(offset);
+  top->set_lod(top_lod);
+  std::vector<int64_t> top_dims_vec{_cap_l, _cap_h};
+  top->Resize(top_dims_vec);
+  auto* top_hidden = top->mutable_data<float>(TARGET(kCUDA));
+
+  const auto* dense_e2h = wi->data<float>();
+  const auto* dense_h2h = wh->data<float>();
+
+  const auto* e2h = dense_e2h;
+  const auto* e2hr = dense_e2h + 1 * _cap_e * _cap_h;
+  const auto* e2hz = dense_e2h + 2 * _cap_e * _cap_h;
+  const auto* h2h = dense_h2h;
+  const auto* h2hr = dense_h2h + 1 * _cap_h * _cap_h;
+  const auto* h2hz = dense_h2h + 2 * _cap_h * _cap_h;
+
+  PrepareLayout(bottom);
+
+  auto* _layout_input = param.layout_input;
+  auto* new_emb = _layout_input->data<float>();
+  const auto& new_offset = _layout_input->lod()[0];
+  int max_width = _layout_input->lod()[0].size() - 1;
+
+  // this buffer is used for bookkeeping info which will be used in bp;
+  // the buffer is also needed in bp, so make it larger
+  _buffer->Resize({20, _cap_l, _cap_h});
+  auto* buffer_data = _buffer->mutable_data<float>(TARGET(kCUDA));
+  auto* w_x_e = buffer_data + 0 * _cap_l * _cap_h;
+  auto* wr_x_e = buffer_data + 1 * _cap_l * _cap_h;
+  auto* wz_x_e = buffer_data + 2 * _cap_l * _cap_h;
+  auto* u_x_h = buffer_data + 3 * _cap_l * _cap_h;
+  auto* ur_x_h = buffer_data + 4 * _cap_l * _cap_h;
+  auto* uz_x_h = buffer_data + 5 * _cap_l * _cap_h;
+  auto* r = buffer_data + 6 * _cap_l * _cap_h;
+  auto* z = buffer_data + 7 * _cap_l * _cap_h;
+  auto* tilde = buffer_data + 8 * _cap_l * _cap_h;
+  // the internal hidden
+  auto* hidden = buffer_data + 19 * _cap_l * _cap_h;
+
+  gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
+  gemm_impl_->run(1.0f, 0.0f, new_emb, e2h, w_x_e, &context);
+  gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
+  gemm_impl_->run(1.0f, 0.0f, new_emb, e2hr, wr_x_e, &context);
+  gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
+  gemm_impl_->run(1.0f, 0.0f, new_emb, e2hz, wz_x_e, &context);
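+
+  // The three GEMMs above compute W x, Wr x and Wz x for all tokens in one
+  // shot. Assuming Gemm::init(trans_a, trans_b, m, n, k, ctx) (our reading
+  // of the call signature), each product is [_cap_l, _cap_e] times
+  // [_cap_h, _cap_e]^T, writing one [_cap_l, _cap_h] slice of the
+  // bookkeeping buffer.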
+
+  // precompute hidden0
+  int num = batch * _cap_h;
+  int threads = 512;
+  int blocks = (num + threads - 1) / threads;
+  PreComputeKernel<float><<<blocks, threads, 0, stream>>>(
+      num, w_x_e, wz_x_e, tilde, z, hidden);
+
+  // recurrence
+  for (int i = 1; i < max_width; i++) {
+    int w_tm1 = new_offset[i] - new_offset[i - 1];
+    int w = new_offset[i + 1] - new_offset[i];
+
+    // project hidden i-1 onto step i
+    auto* htm1 = hidden + new_offset[i - 1] * _cap_h;
+
+    gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
+    gemm_impl_->run(
+        1.0f, 0.0f, htm1, h2h, u_x_h + new_offset[i] * _cap_h, &context);
+    gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
+    gemm_impl_->run(
+        1.0f, 0.0f, htm1, h2hr, ur_x_h + new_offset[i] * _cap_h, &context);
+    gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
+    gemm_impl_->run(
+        1.0f, 0.0f, htm1, h2hz, uz_x_h + new_offset[i] * _cap_h, &context);
+
+    // compute the gates and hidden state for this step
+    int start = new_offset[i] * _cap_h;
+    int end = (new_offset[i] + w) * _cap_h;
+    blocks = (end - start + threads - 1) / threads;
+    PostComputeKernel<float><<<blocks, threads, 0, stream>>>(start,
+                                                             end,
+                                                             _cap_h,
+                                                             w_tm1,
+                                                             wr_x_e,
+                                                             ur_x_h,
+                                                             wz_x_e,
+                                                             uz_x_h,
+                                                             w_x_e,
+                                                             u_x_h,
+                                                             r,
+                                                             z,
+                                                             tilde,
+                                                             hidden);
+  }
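+
+  // At this point `hidden` is laid out step-major (all step-0 rows, then
+  // all step-1 rows, and so on), while the op output is sequence-major, so
+  // the final copy scatters every row back to its original position.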
+  CopyBack(hidden, top_hidden, _cap_h);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(search_grnn,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::SearchGrnnCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Wi",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Wh",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .BindOutput("tmp_buffer",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .BindOutput("idx_sorted_by_width",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .BindOutput("layout_input",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/cuda/search_grnn_compute.h b/lite/kernels/cuda/search_grnn_compute.h
new file mode 100644
index 0000000000..73d84635d0
--- /dev/null
+++ b/lite/kernels/cuda/search_grnn_compute.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include "lite/backends/cuda/blas.h"
+#include "lite/backends/cuda/math/gemm.h"
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class SearchGrnnCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SearchGrnnParam;
+  using TargetW = TargetWrapper<TARGET(kCUDA)>;
+
+  void PrepareForRun() override;
+  void Run() override;
+  virtual ~SearchGrnnCompute() = default;
+
+ private:
+  std::shared_ptr<Tensor> idx_sorted_by_width_cpu;
+  std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
+  void PrepareLayout(const Tensor* input);
+  void CopyBack(float* from, float* to, int step);
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/search_grnn_compute_test.cc b/lite/kernels/cuda/search_grnn_compute_test.cc
new file mode 100644
index 0000000000..08b96e1f1e
--- /dev/null
+++ b/lite/kernels/cuda/search_grnn_compute_test.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/search_grnn_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/api/test_helper.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+using Tensor = lite::Tensor;
+
+TEST(search_grnn, normal) {
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  SearchGrnnCompute kernel;
+  operators::SearchGrnnParam param;
+
+  int num_input = 6;
+  int num_hidden = 6;
+  int num_batch = 3;
+  Tensor x, wi, wh, out, idx_sorted_by_width, layout_input, tmp_buffer;
+  x.Resize({num_batch, num_input});
+  wi.Resize({3, num_hidden, num_input});
+  wh.Resize({3, num_hidden, num_hidden});
+  LoD x_lod{};
+  x_lod.push_back({0, 1, 3});
+  x.set_lod(x_lod);
+
+  Tensor x_cpu, wi_cpu, wh_cpu, out_cpu, layout_input_cpu, tmp_buffer_cpu;
+  x_cpu.Resize({num_batch, num_input});
+  wi_cpu.Resize({3, num_hidden, num_input});
+  wh_cpu.Resize({3, num_hidden, num_hidden});
+  out_cpu.Resize({num_batch, num_hidden});
+  layout_input_cpu.Resize({num_batch, num_input});
+  tmp_buffer_cpu.Resize({20, num_batch, num_hidden});
+  auto* x_cpu_data = x_cpu.mutable_data<float>();
+  for (int i = 0; i < x_cpu.numel(); ++i) {
+    x_cpu_data[i] = static_cast<float>(i);
+  }
+  auto* wi_cpu_data = wi_cpu.mutable_data<float>();
+  for (int i = 0; i < wi_cpu.numel(); ++i) {
+    wi_cpu_data[i] = static_cast<float>(i);
+  }
+  auto* wh_cpu_data = wh_cpu.mutable_data<float>();
+  for (int i = 0; i < wh_cpu.numel(); ++i) {
+    wh_cpu_data[i] = static_cast<float>(i);
+  }
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  wi.Assign<float, lite::DDim, TARGET(kCUDA)>(wi_cpu_data, wi_cpu.dims());
+  wh.Assign<float, lite::DDim, TARGET(kCUDA)>(wh_cpu_data, wh_cpu.dims());
+
+  param.x = &x;
+  param.wi = &wi;
+  param.wh = &wh;
+  param.out = &out;
+  param.idx_sorted_by_width = &idx_sorted_by_width;
+  param.layout_input = &layout_input;
+  param.tmp_buffer = &tmp_buffer;
+  param.num_input = num_input;
+  param.num_hidden = num_hidden;
+  kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
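+  // The compute kernel enqueues its GEMMs and CUDA kernels on the context's
+  // exec stream, so the stream must be attached before Launch();
+  // cudaDeviceSynchronize() below then waits for that async work to finish
+  // before the device-to-host copy of the results (explanatory note, ours).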
+  context.SetExecStream(stream);
+  kernel.SetContext(std::move(ctx));
+  kernel.Launch();
+  cudaDeviceSynchronize();
+
+  auto* out_cpu_data = out_cpu.mutable_data<float>();
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+  CopySync<TARGET(kCUDA)>(
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+  LOG(INFO) << "out_data:";
+  for (int i = 0; i < out.numel(); i++) {
+    // EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-5);
+    LOG(INFO) << out_cpu_data[i];
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index ed2ac6637e..80ae34be41 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -46,7 +46,7 @@ add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS
 add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(match_matrix_tensor_compute_x86 X86 basic SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps} blas math_function)
 add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_compute.cc DEPS ${lite_kernel_deps})
-add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps} blas math_function)
 add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
 add_kernel(attention_padding_mask_compute_x86 X86 basic SRCS attention_padding_mask_compute.cc DEPS ${lite_kernel_deps})
@@ -89,12 +89,12 @@ lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_
 lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
 lite_cc_test(test_search_fc_compute_x86 SRCS search_fc_compute_test.cc DEPS search_fc_compute_x86)
 lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86)
-
+lite_cc_test(test_search_grnn_compute_x86 SRCS search_grnn_compute_test.cc DEPS search_grnn_compute_x86)
+lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
 lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
 lite_cc_test(test_stack_compute_x86 SRCS stack_compute_test.cc DEPS stack_compute_x86)
 lite_cc_test(test_search_group_padding_compute_x86 SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_x86)
 lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
-lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
 lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
 lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86)
 lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
diff --git a/lite/kernels/x86/search_grnn_compute.h b/lite/kernels/x86/search_grnn_compute.h
index 5082faf45b..66866761e1 100644
--- a/lite/kernels/x86/search_grnn_compute.h
+++ b/lite/kernels/x86/search_grnn_compute.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include "lite/backends/x86/math/blas.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
diff --git a/lite/kernels/x86/search_grnn_compute_test.cc b/lite/kernels/x86/search_grnn_compute_test.cc
new file mode 100644
index 0000000000..b85d97e3f1
--- /dev/null
+++ b/lite/kernels/x86/search_grnn_compute_test.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/search_grnn_compute.h"
+#include <gtest/gtest.h>
+#include <iostream>
+#include <memory>
+#include <utility>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(search_grnn_x86, retrive_op) {
+  auto kernel =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "search_grnn");
+  ASSERT_FALSE(kernel.empty());
+  ASSERT_TRUE(kernel.front());
+}
+
+TEST(search_grnn_x86, init) {
+  SearchGrnnCompute<float> ssdc;
+  ASSERT_EQ(ssdc.precision(), PRECISION(kFloat));
+  ASSERT_EQ(ssdc.target(), TARGET(kX86));
+}
+
+TEST(search_grnn_x86, run_test) {
+  int num_input = 128;
+  int num_hidden = 128;
+  int num_batch = 3;
+  lite::Tensor x, wi, wh, out, idx_sorted_by_width, layout_input, tmp_buffer;
+  x.Resize({num_batch, num_input});
+  wi.Resize({3, num_hidden, num_input});
+  wh.Resize({3, num_hidden, num_hidden});
+  // out.Resize({num_batch, num_hidden});
+  LoD x_lod{};
+  x_lod.push_back({0, 1, 3});
+  x.set_lod(x_lod);
+
+  auto* x_data = x.mutable_data<float>();
+  for (int64_t i = 0; i < x.numel(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  auto* wi_data = wi.mutable_data<float>();
+  for (int64_t i = 0; i < wi.numel(); i++) {
+    wi_data[i] = static_cast<float>(i);
+  }
+  auto* wh_data = wh.mutable_data<float>();
+  for (int64_t i = 0; i < wh.numel(); i++) {
+    wh_data[i] = static_cast<float>(i);
+  }
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+
+  operators::SearchGrnnParam param;
+  param.x = &x;
+  param.wi = &wi;
+  param.wh = &wh;
+  param.out = &out;
+  param.idx_sorted_by_width = &idx_sorted_by_width;
+  param.layout_input = &layout_input;
+  param.tmp_buffer = &tmp_buffer;
+  param.num_input = num_input;
+  param.num_hidden = num_hidden;
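+
+  // Shape conventions, as we read them from the kernel code: the lod
+  // {0, 1, 3} describes two sequences of lengths 1 and 2 over num_batch = 3
+  // total tokens, and wi/wh pack the three gate weight matrices {W, Wr, Wz}
+  // and {U, Ur, Uz} along dimension 0, which is why both have a leading
+  // dimension of 3.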
+
+  SearchGrnnCompute<float> sgc;
+  sgc.SetContext(std::move(ctx));
+  sgc.SetParam(param);
+  sgc.Run();
+
+  // std::vector<float> ref_results = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19};
+  auto* out_data = out.mutable_data<float>();
+  LOG(INFO) << out.numel();
+  for (int i = 0; i < out.numel(); i++) {
+    // EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
+    LOG(INFO) << out_data[i];
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(search_grnn, kX86, kFloat, kNCHW, def);
-- 
GitLab