未验证 提交 e1b67433 编写于 作者: J juncaipeng 提交者: GitHub

fix x86 search_grnn, add cuda search_grnn and unit test (#2448)

* fix x86 search_grnn and add unit test
* add cuda search_grnn and unit test
上级 fa8c8971
......@@ -23,13 +23,14 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(match_matrix_tensor_compute_cuda CUDA extra SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm)
add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(var_conv_2d_compute_cuda CUDA basic SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
......@@ -49,15 +50,17 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
if(LITE_BUILD_EXTRA)
nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
nv_test(search_grnn_compute_cuda_test SRCS search_grnn_compute_test.cc DEPS search_grnn_compute_cuda)
nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
nv_test(search_aligned_mat_mul_compute_cuda_test SRCS search_aligned_mat_mul_compute_test.cc DEPS search_aligned_mat_mul_compute_cuda)
nv_test(search_seq_fc_compute_cuda_test SRCS search_seq_fc_compute_test.cc DEPS search_seq_fc_compute_cuda)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_grnn_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
template <typename T>
T sigmoid(T z) {
return 1 / (1 + std::exp(-z));
}
template <typename T>
__global__ void PreComputeKernel(
const int num, const T* w_x_e, const T* wz_x_e, T* tilde, T* z, T* hidden) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
tilde[index] = std::tanh(w_x_e[index]);
z[index] = 1 / (1 + std::exp(-wz_x_e[index]));
hidden[index] = (1. - z[index]) * tilde[index];
}
}
template <typename T>
__global__ void PostComputeKernel(const int start,
const int end,
const int cap_h,
const int w_tm1,
const T* wr_x_e,
const T* ur_x_h,
const T* wz_x_e,
const T* uz_x_h,
const T* w_x_e,
const T* u_x_h,
T* r,
T* z,
T* tilde,
T* hidden) {
int j = start + blockIdx.x * blockDim.x + threadIdx.x;
if (j < end) {
r[j] = 1 / (1 + std::exp(-(wr_x_e[j] + ur_x_h[j])));
z[j] = 1 / (1 + std::exp(-(wz_x_e[j] + uz_x_h[j])));
tilde[j] = std::tanh(w_x_e[j] + r[j] * u_x_h[j]);
hidden[j] = z[j] * hidden[j - cap_h * w_tm1] + (1.0 - z[j]) * tilde[j];
}
}
void SearchGrnnCompute::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
}
void SearchGrnnCompute::PrepareLayout(const Tensor* input_blob) {
auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>();
auto cuda_stream = context.exec_stream();
auto* _input = input_blob;
int dim0 = _input->dims()[0];
int dim1 = 1;
if (_input->dims().size() > 1) {
dim1 = _input->dims()[1];
}
int batch = _input->lod()[0].size() - 1;
auto& offset = _input->lod()[0];
idx_sorted_by_width_cpu = std::make_shared<Tensor>();
idx_sorted_by_width_cpu->Resize({batch});
int* idx_sorted_by_width_cpu_data =
idx_sorted_by_width_cpu->mutable_data<int>();
Tensor _width;
_width.Resize({batch});
int* width_data = _width.mutable_data<int>();
// sort sequence by width (descending) and find the largest width in the
// batch
for (int i = 0; i < batch; i++) {
width_data[i] = offset[i + 1] - offset[i];
idx_sorted_by_width_cpu_data[i] = i;
}
std::sort(idx_sorted_by_width_cpu_data,
idx_sorted_by_width_cpu_data + batch,
[&_width](int a, int b) {
return _width.data<int>()[a] > _width.data<int>()[b];
});
int max_width = width_data[idx_sorted_by_width_cpu_data[0]];
// start of reorganizing the input
std::vector<size_t> new_offset;
new_offset.resize(max_width + 1);
new_offset[0] = 0;
int j = batch - 1;
int last_width = 0;
int sub_row = 0;
int sub_col = 0;
for (int i = 1; i <= max_width;) {
for (int k = j; k >= 0; --k) {
if (width_data[idx_sorted_by_width_cpu_data[k]] > last_width) {
sub_row = width_data[idx_sorted_by_width_cpu_data[k]] - last_width;
sub_col = k + 1;
for (int s = 0; s < sub_row; s++) {
new_offset[i] = new_offset[i - 1] + sub_col;
i++;
}
// move on
last_width = width_data[idx_sorted_by_width_cpu_data[k]];
j = k - 1;
break;
}
}
}
// copying to the reorganized buffer
auto* _layout_input = new Tensor();
auto* _layout_input_gpu = param.layout_input;
if (_input->dims().size() == 1) {
// _layout_input.reshape_batch_sequence({dim0}, new_offset);
LOG(FATAL) << "_input->dims().size() = 1, error.";
} else {
// _layout_input.reshape_batch_sequence({dim0, dim1}, new_offset);
LoD new_lod;
new_lod.push_back(new_offset);
_layout_input->set_lod(new_lod);
_layout_input->Resize({dim0, dim1});
_layout_input_gpu->set_lod(new_lod);
_layout_input_gpu->Resize({dim0, dim1});
}
auto* new_emb = _layout_input->mutable_data<float>();
auto* input_cpu = new Tensor();
input_cpu->Resize(_input->dims());
auto* input_cpu_data = input_cpu->mutable_data<float>();
TargetW::MemcpyAsync(input_cpu_data,
_input->data<float>(),
_input->numel() * sizeof(float),
IoDirection::DtoH,
cuda_stream);
for (int i = 0; i < max_width; i++) {
int w = new_offset[i + 1] - new_offset[i];
auto* emb_start = new_emb + dim1 * new_offset[i];
for (int j = 0; j < w; ++j) {
memcpy(emb_start + dim1 * j,
input_cpu_data + dim1 * offset[idx_sorted_by_width_cpu_data[j]] +
dim1 * i,
dim1 * sizeof(float));
}
}
auto* _layout_input_gpu_data =
_layout_input_gpu->mutable_data<float>(TARGET(kCUDA));
TargetW::MemcpyAsync(_layout_input_gpu_data,
new_emb,
_layout_input->numel() * sizeof(float),
IoDirection::HtoD,
cuda_stream);
delete _layout_input;
delete input_cpu;
}
void SearchGrnnCompute::CopyBack(float* from, float* to, int step) {
auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto* _input = param.x;
auto* _layout_input = param.layout_input;
const auto& offset = _input->lod()[0];
const auto& new_offset = _layout_input->lod()[0];
const auto* idx_sorted_by_width_cpu_data =
idx_sorted_by_width_cpu->data<int>();
for (size_t i = 0; i < _layout_input->lod()[0].size() - 1; ++i) {
int w = new_offset[i + 1] - new_offset[i];
for (int j = 0; j < w; j++) {
TargetW::MemcpyAsync(
to + step * (offset[idx_sorted_by_width_cpu_data[j]] + i),
from + (new_offset[i] + j) * step,
step * sizeof(float),
IoDirection::DtoD,
stream);
}
}
}
void SearchGrnnCompute::Run() {
CHECK(ctx_) << "running context should be set first";
auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto* bottom = param.x;
auto* wi = param.wi;
auto* wh = param.wh;
auto* top = param.out;
auto* _buffer = param.tmp_buffer;
int _cap_h = param.num_hidden;
int _cap_e = param.num_input;
int _cap_l = bottom->dims()[0];
int batch = bottom->lod()[0].size() - 1;
const auto& offset = bottom->lod()[0];
LoD top_lod;
top_lod.push_back(offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{_cap_l, _cap_h};
top->Resize(top_dims_vec);
auto* top_hidden = top->mutable_data<float>(TARGET(kCUDA));
const auto* dense_e2h = wi->data<float>();
const auto* dense_h2h = wh->data<float>();
const auto* e2h = dense_e2h;
const auto* e2hr = dense_e2h + 1 * _cap_e * _cap_h;
const auto* e2hz = dense_e2h + 2 * _cap_e * _cap_h;
const auto* h2h = dense_h2h;
const auto* h2hr = dense_h2h + 1 * _cap_h * _cap_h;
const auto* h2hz = dense_h2h + 2 * _cap_h * _cap_h;
PrepareLayout(bottom);
auto* _layout_input = param.layout_input;
auto* new_emb = _layout_input->data<float>();
const auto& new_offset = _layout_input->lod()[0];
int max_width = _layout_input->lod()[0].size() - 1;
// this buffer is used for book keeping info which will be used in bp
// buffer also needed in bp, so make it larger
_buffer->Resize({20, _cap_l, _cap_h});
auto* buffer_data = _buffer->mutable_data<float>(TARGET(kCUDA));
auto* w_x_e = buffer_data + 0 * _cap_l * _cap_h;
auto* wr_x_e = buffer_data + 1 * _cap_l * _cap_h;
auto* wz_x_e = buffer_data + 2 * _cap_l * _cap_h;
auto* u_x_h = buffer_data + 3 * _cap_l * _cap_h;
auto* ur_x_h = buffer_data + 4 * _cap_l * _cap_h;
auto* uz_x_h = buffer_data + 5 * _cap_l * _cap_h;
auto* r = buffer_data + 6 * _cap_l * _cap_h;
auto* z = buffer_data + 7 * _cap_l * _cap_h;
auto* tilde = buffer_data + 8 * _cap_l * _cap_h;
// the internal hidden
auto* hidden = buffer_data + 19 * _cap_l * _cap_h;
gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
gemm_impl_->run(1.0f, 0.0f, new_emb, e2h, w_x_e, &context);
gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
gemm_impl_->run(1.0f, 0.0f, new_emb, e2hr, wr_x_e, &context);
gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
gemm_impl_->run(1.0f, 0.0f, new_emb, e2hz, wz_x_e, &context);
// precompute hidden0
int num = batch * _cap_h;
int threads = 512;
int blocks = (num + threads - 1) / threads;
PreComputeKernel<<<blocks, threads, 0, stream>>>(
num, w_x_e, wz_x_e, tilde, z, hidden);
// recurrence
for (int i = 1; i < max_width; i++) {
int w_tm1 = new_offset[i] - new_offset[i - 1];
int w = new_offset[i + 1] - new_offset[i];
// precompute hidden i-1 to hidden i
auto* htm1 = hidden + new_offset[i - 1] * _cap_h;
gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
gemm_impl_->run(
1.0f, 0.0f, htm1, h2h, u_x_h + new_offset[i] * _cap_h, &context);
gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
gemm_impl_->run(
1.0f, 0.0f, htm1, h2hr, ur_x_h + new_offset[i] * _cap_h, &context);
gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
gemm_impl_->run(
1.0f, 0.0f, htm1, h2hz, uz_x_h + new_offset[i] * _cap_h, &context);
// compute the gate and hidden
int start = new_offset[i] * _cap_h;
int end = (new_offset[i] + w) * _cap_h;
PostComputeKernel<<<blocks, threads, 0, stream>>>(start,
end,
_cap_h,
w_tm1,
wr_x_e,
ur_x_h,
wz_x_e,
uz_x_h,
w_x_e,
u_x_h,
r,
z,
tilde,
hidden);
}
CopyBack(hidden, top_hidden, _cap_h);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_grnn,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SearchGrnnCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Wi",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Wh",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("tmp_buffer",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("idx_sorted_by_width",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("layout_input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SearchGrnnCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::SearchGrnnParam;
using TargetW = TargetWrapper<TARGET(kCUDA)>;
void PrepareForRun() override;
void Run() override;
virtual ~SearchGrnnCompute() = default;
private:
std::shared_ptr<Tensor> idx_sorted_by_width_cpu;
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
void PrepareLayout(const Tensor* input);
void CopyBack(float* from, float* to, int step);
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_grnn_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
TEST(search_grnn, normal) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
SearchGrnnCompute kernel;
operators::SearchGrnnParam param;
int num_input = 6;
int num_hidden = 6;
int num_batch = 3;
Tensor x, wi, wh, out, idx_sorted_by_width, layout_input, tmp_buffer;
x.Resize({num_batch, num_input});
wi.Resize({3, num_hidden, num_input});
wh.Resize({3, num_hidden, num_hidden});
LoD x_lod{};
x_lod.push_back({0, 1, 3});
x.set_lod(x_lod);
Tensor x_cpu, wi_cpu, wh_cpu, out_cpu, layout_input_cpu, tmp_buffer_cpu;
x_cpu.Resize({num_batch, num_input});
wi_cpu.Resize({3, num_hidden, num_input});
wh_cpu.Resize({3, num_hidden, num_hidden});
out_cpu.Resize({num_batch, num_hidden});
layout_input_cpu.Resize({num_batch, num_input});
tmp_buffer_cpu.Resize({20, num_batch, num_hidden});
auto* x_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = static_cast<float>(i);
}
auto* wi_cpu_data = wi_cpu.mutable_data<float>();
for (int i = 0; i < wi_cpu.numel(); ++i) {
wi_cpu_data[i] = static_cast<float>(i);
}
auto* wh_cpu_data = wh_cpu.mutable_data<float>();
for (int i = 0; i < wh_cpu.numel(); ++i) {
wh_cpu_data[i] = static_cast<float>(i);
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
wi.Assign<float, lite::DDim, TARGET(kCUDA)>(wi_cpu_data, wi_cpu.dims());
wh.Assign<float, lite::DDim, TARGET(kCUDA)>(wh_cpu_data, wh_cpu.dims());
param.x = &x;
param.wi = &wi;
param.wh = &wh;
param.out = &out;
param.idx_sorted_by_width = &idx_sorted_by_width;
param.layout_input = &layout_input;
param.tmp_buffer = &tmp_buffer;
param.num_input = num_input;
param.num_hidden = num_hidden;
kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
kernel.SetContext(std::move(ctx));
kernel.Launch();
cudaDeviceSynchronize();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
LOG(INFO) << "out_data:";
for (int i = 0; i < out.numel(); i++) {
// EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-5);
LOG(INFO) << out_cpu_data[i];
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -46,7 +46,7 @@ add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_x86 X86 basic SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps} blas math_function)
add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps} blas math_function)
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
add_kernel(attention_padding_mask_compute_x86 X86 basic SRCS attention_padding_mask_compute.cc DEPS ${lite_kernel_deps})
......@@ -89,12 +89,12 @@ lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
lite_cc_test(test_search_fc_compute_x86 SRCS search_fc_compute_test.cc DEPS search_fc_compute_x86)
lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86)
lite_cc_test(test_search_grnn_compute_x86 SRCS search_grnn_compute_test.cc DEPS search_grnn_compute_x86)
lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
lite_cc_test(test_stack_compute_x86 SRCS stack_compute_test.cc DEPS stack_compute_x86)
lite_cc_test(test_search_group_padding_compute_x86 SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_x86)
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86)
lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_grnn_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(search_grnn_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_grnn");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
TEST(search_grnn_x86, init) {
SearchGrnnCompute<float> ssdc;
ASSERT_EQ(ssdc.precision(), PRECISION(kFloat));
ASSERT_EQ(ssdc.target(), TARGET(kX86));
}
TEST(search_grnn_x86, run_test) {
int num_input = 128;
int num_hidden = 128;
int num_batch = 3;
lite::Tensor x, wi, wh, out, idx_sorted_by_width, layout_input, tmp_buffer;
x.Resize({num_batch, num_input});
wi.Resize({3, num_hidden, num_input});
wh.Resize({3, num_hidden, num_hidden});
// out.Resize({num_batch, num_hidden});
LoD x_lod{};
x_lod.push_back({0, 1, 3});
x.set_lod(x_lod);
auto* x_data = x.mutable_data<float>();
for (int64_t i = 0; i < x.numel(); i++) {
x_data[i] = static_cast<float>(i);
}
auto* wi_data = wi.mutable_data<float>();
for (int64_t i = 0; i < wi.numel(); i++) {
wi_data[i] = static_cast<float>(i);
}
auto* wh_data = wh.mutable_data<float>();
for (int64_t i = 0; i < wh.numel(); i++) {
wh_data[i] = static_cast<float>(i);
}
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
operators::SearchGrnnParam param;
param.x = &x;
param.wi = &wi;
param.wh = &wh;
param.out = &out;
param.idx_sorted_by_width = &idx_sorted_by_width;
param.layout_input = &layout_input;
param.tmp_buffer = &tmp_buffer;
param.num_input = num_input;
param.num_hidden = num_hidden;
SearchGrnnCompute<float> sgc;
sgc.SetContext(std::move(ctx));
sgc.SetParam(param);
sgc.Run();
// std::vector<float> ref_results = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19};
auto* out_data = out.mutable_data<float>();
LOG(INFO) << out.numel();
for (int i = 0; i < out.numel(); i++) {
// EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
LOG(INFO) << out_data[i];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(search_grnn, kX86, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册