From 7f408ee8cae3249ce2829eb70dc3d6da71ba52b4 Mon Sep 17 00:00:00 2001
From: juncaipeng <52520497+juncaipeng@users.noreply.github.com>
Date: Fri, 15 Nov 2019 22:22:01 +0800
Subject: [PATCH] Add content-dnn ops (#2429)

* add search_seq_depadding x86 and cuda
* add match_matrix_tensor x86
* add search_grnn x86, no test
---
 lite/kernels/cuda/CMakeLists.txt               |   2 +
 .../cuda/search_seq_depadding_compute.cu       | 110 ++++++
 .../cuda/search_seq_depadding_compute.h        |  39 ++
 .../cuda/search_seq_depadding_compute_test.cc  |  88 +++++
 lite/kernels/x86/CMakeLists.txt                |   5 +
 .../x86/match_matrix_tensor_compute.cc         | 119 +++++++
 .../kernels/x86/match_matrix_tensor_compute.h  |  42 +++
 .../x86/match_matrix_tensor_compute_test.cc    | 116 ++++++
 lite/kernels/x86/search_grnn_compute.cc        | 332 ++++++++++++++++++
 lite/kernels/x86/search_grnn_compute.h         |  43 +++
 .../x86/search_seq_depadding_compute.cc        |  76 ++++
 .../x86/search_seq_depadding_compute.h         |  40 +++
 .../x86/search_seq_depadding_compute_test.cc   |  83 +++++
 lite/operators/CMakeLists.txt                  |   3 +
 lite/operators/match_matrix_tensor_op.cc       | 102 ++++++
 lite/operators/match_matrix_tensor_op.h        |  49 +++
 lite/operators/op_params.h                     |  32 ++
 lite/operators/search_grnn_op.cc               |  94 +++++
 lite/operators/search_grnn_op.h                |  48 +++
 lite/operators/search_seq_depadding_op.cc      |  71 ++++
 lite/operators/search_seq_depadding_op.h       |  49 +++
 21 files changed, 1543 insertions(+)
 create mode 100644 lite/kernels/cuda/search_seq_depadding_compute.cu
 create mode 100644 lite/kernels/cuda/search_seq_depadding_compute.h
 create mode 100644 lite/kernels/cuda/search_seq_depadding_compute_test.cc
 create mode 100644 lite/kernels/x86/match_matrix_tensor_compute.cc
 create mode 100644 lite/kernels/x86/match_matrix_tensor_compute.h
 create mode 100644 lite/kernels/x86/match_matrix_tensor_compute_test.cc
 create mode 100644 lite/kernels/x86/search_grnn_compute.cc
 create mode 100644 lite/kernels/x86/search_grnn_compute.h
 create mode 100644 lite/kernels/x86/search_seq_depadding_compute.cc
 create mode 100644 lite/kernels/x86/search_seq_depadding_compute.h
 create mode 100644 lite/kernels/x86/search_seq_depadding_compute_test.cc
 create mode 100644 lite/operators/match_matrix_tensor_op.cc
 create mode 100644 lite/operators/match_matrix_tensor_op.h
 create mode 100644 lite/operators/search_grnn_op.cc
 create mode 100644 lite/operators/search_grnn_op.h
 create mode 100644 lite/operators/search_seq_depadding_op.cc
 create mode 100644 lite/operators/search_seq_depadding_op.h

diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 24e306b420..04b1902b70 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
 add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
@@ -40,6 +41,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
 nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
 nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
+nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
 nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
 nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
 if(LITE_BUILD_EXTRA)
diff --git a/lite/kernels/cuda/search_seq_depadding_compute.cu b/lite/kernels/cuda/search_seq_depadding_compute.cu
new file mode 100644
index 0000000000..179041596b
--- /dev/null
+++ b/lite/kernels/cuda/search_seq_depadding_compute.cu
@@ -0,0 +1,110 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/search_seq_depadding_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+using Tensor = lite::Tensor;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+template <typename Dtype>
+__global__ void ker_sequence_depadding_fwd(Dtype* out_data,
+                                           const Dtype* in_data,
+                                           const int* seq_id_map,
+                                           const int seq_num,
+                                           const int max_len,
+                                           const int emb_size,
+                                           const int count) {
+  CUDA_KERNEL_LOOP(tid, count) {
+    int emb_id = tid % emb_size;
+    int word_id = tid / emb_size;
+    int seq_id = seq_id_map[word_id];
+    out_data[tid] = in_data[seq_id * emb_size + emb_id];
+  }
+}
+
+void SearchSeqDepaddingCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto cuda_stream = ctx.exec_stream();
+
+  auto* pad = param.pad;
+  auto* src = param.src;
+  auto* out = param.out;
+
+  auto* in_data = pad->data<float>();
+  auto* out_data = out->mutable_data<float>(TARGET(kCUDA));
+  const int count = out->numel();
+
+  const auto& pad_seq_offset = pad->lod()[0];
+  const auto& src_seq_offset = src->lod()[0];
+  int max_len = pad_seq_offset[1];
+  int seq_num = pad_seq_offset.size() - 1;
+  int emb_size = pad->dims()[1];
+
+  std::vector<int> seq_id_map;
+  for (int i = 0; i < seq_num; i++) {
+    int cur_len = src_seq_offset[i + 1] - src_seq_offset[i];
+    for (int j = 0; j < cur_len; j++) {
+      seq_id_map.push_back(i * max_len + j);
+    }
+  }
+
+  int map_size = seq_id_map.size();
+  seq_id_map_tensor.Resize({map_size, 1, 1, 1});
+  int* seq_id_map_data = seq_id_map_tensor.mutable_data<int>(TARGET(kCUDA));
+  TargetW::MemcpyAsync(seq_id_map_data,
+                       &seq_id_map[0],
+                       seq_id_map.size() * sizeof(int),
+                       IoDirection::HtoD,
+                       cuda_stream);
+
+  int threads = 512;
+  int blocks = (count + threads - 1) / threads;
+  ker_sequence_depadding_fwd<float><<<blocks, threads, 0, cuda_stream>>>(
+      out_data, in_data, seq_id_map_data, seq_num, max_len, emb_size, count);
+
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(search_seq_depadding,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::SearchSeqDepaddingCompute,
+                     def)
+    .BindInput("Src",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Pad",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/cuda/search_seq_depadding_compute.h b/lite/kernels/cuda/search_seq_depadding_compute.h
new file mode 100644
index 0000000000..a06f39bee2
--- /dev/null
+++ b/lite/kernels/cuda/search_seq_depadding_compute.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class SearchSeqDepaddingCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SearchSeqDepaddingParam;
+  using TargetW = TargetWrapper<TARGET(kCUDA)>;
+
+  void Run() override;
+  virtual ~SearchSeqDepaddingCompute() = default;
+
+ private:
+  Tensor seq_id_map_tensor;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/search_seq_depadding_compute_test.cc b/lite/kernels/cuda/search_seq_depadding_compute_test.cc
new file mode 100644
index 0000000000..9c23ff14ab
--- /dev/null
+++ b/lite/kernels/cuda/search_seq_depadding_compute_test.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/search_seq_depadding_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/api/test_helper.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+using Tensor = lite::Tensor;
+
+TEST(search_seq_depadding, normal) {
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  SearchSeqDepaddingCompute kernel;
+  operators::SearchSeqDepaddingParam param;
+
+  Tensor pad, src, out;
+  pad.Resize({2 * 3, 4});
+  src.Resize({3, 1});
+  out.Resize({3, 4});
+  LoD pad_lod{};
+  pad_lod.push_back({0, 4, 6});
+  pad.set_lod(pad_lod);
+  LoD src_lod{};
+  src_lod.push_back({0, 2, 3});
+  src.set_lod(src_lod);
+
+  Tensor pad_cpu, src_cpu, out_cpu;
+  pad_cpu.Resize({2 * 3, 4});
+  src_cpu.Resize({3, 1});
+  out_cpu.Resize({3, 4});
+
+  auto* pad_cpu_data = pad_cpu.mutable_data<float>();
+  auto* src_cpu_data = src_cpu.mutable_data<float>();
+  for (int i = 0; i < pad_cpu.numel(); ++i) {
+    pad_cpu_data[i] = static_cast<float>(i);
+  }
+
+  pad.Assign<float, lite::DDim, TARGET(kCUDA)>(pad_cpu_data, pad_cpu.dims());
+  src.Assign<float, lite::DDim, TARGET(kCUDA)>(src_cpu_data, src_cpu.dims());
+
+  param.pad = &pad;
+  param.src = &src;
+  param.out = &out;
+  kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+  kernel.SetContext(std::move(ctx));
+  kernel.Launch();
+  cudaDeviceSynchronize();
+
+  auto* out_cpu_data = out_cpu.mutable_data<float>();
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+  CopySync<TARGET(kCUDA)>(
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+
+  std::vector<float> ref_results = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19};
+  for (int i = 0; i < out.numel(); i++) {
+    EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-5);
+    // LOG(INFO) << out_cpu_data[i];
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index b95e541e83..8bdc2a17c1 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -41,6 +41,9 @@ add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${li
 add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(match_matrix_tensor_compute_x86 X86 basic SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps} blas math_function)
+add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
@@ -72,9 +75,11 @@ lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x
 lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
 lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
+lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86)
 if(LITE_BUILD_EXTRA)
     lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
 endif()
 lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
+lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
 lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
diff --git a/lite/kernels/x86/match_matrix_tensor_compute.cc b/lite/kernels/x86/match_matrix_tensor_compute.cc
new file mode 100644
index 0000000000..a0b4160c3a
--- /dev/null
+++ b/lite/kernels/x86/match_matrix_tensor_compute.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/match_matrix_tensor_compute.h"
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+void MatchMatrixTensorCompute<T>::Run() {
+  auto& context = ctx_->As<X86Context>();
+  auto& param = this->Param<param_t>();
+  auto* x = param.x;
+  auto* w = param.w;
+  auto* y = param.y;
+  auto* out = param.out;
+  auto* tmp = param.tmp;
+  int dim_t = param.dim_t;
+  int dim_in = x->dims()[1];
+
+  const auto& offset_l = x->lod()[0];
+  const auto& offset_r = y->lod()[0];
+
+  std::vector<size_t> top_offset;
+  int top_size = 0;
+  top_offset.push_back(top_size);
+  for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
+    int len_l = offset_l[b + 1] - offset_l[b];
+    int len_r = offset_r[b + 1] - offset_r[b];
+    top_size += dim_t * len_l * len_r;
+    top_offset.push_back(top_size);
+  }
+
+  auto* bottom_l_data = x->template data<T>();
+  auto* bottom_r_data = y->template data<T>();
+  auto* t_data = w->template data<T>();
+  auto* out_data = out->template mutable_data<T>();
+  auto* bottom_l_trans_data = tmp->template mutable_data<T>();
+  memset(out_data, 0, out->dims()[0] * out->dims()[1] * sizeof(T));
+  memset(bottom_l_trans_data, 0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));
+
+  auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
+  blas.GEMM(CblasNoTrans,
+            CblasNoTrans,
+            x->dims()[0],
+            dim_t * dim_in,
+            dim_in,
+            1.0f,
+            bottom_l_data,
+            dim_in,
+            t_data,
+            dim_t * dim_in,
+            0.0f,
+            bottom_l_trans_data,
+            dim_t * dim_in);
+
+  for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
+    for (int t = 0; t < dim_t; t++) {
+      int len_l = offset_l[b + 1] - offset_l[b];
+      int len_r = offset_r[b + 1] - offset_r[b];
+      auto* top_data = out_data + top_offset[b] + t * len_l * len_r;
+      const auto* l_t_data =
+          bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
+      const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
+
+      auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
+      blas.GEMM(CblasNoTrans,
+                CblasTrans,
+                len_l,
+                len_r,
+                dim_in,
+                1.0f,
+                l_t_data,
+                dim_t * dim_in,
+                r_data,
+                dim_in,
+                0.0f,
+                top_data,
+                len_r);
+    }
+  }
+
+  LoD out_lod;
+  out_lod.push_back(top_offset);
+  out->set_lod(out_lod);
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    match_matrix_tensor,
+    kX86,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::x86::MatchMatrixTensorCompute<float>,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Tmp", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/match_matrix_tensor_compute.h b/lite/kernels/x86/match_matrix_tensor_compute.h
new file mode 100644
index 0000000000..6189676fd8
--- /dev/null
+++ b/lite/kernels/x86/match_matrix_tensor_compute.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <vector>
+#include "lite/backends/x86/math/blas.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/operators/op_params.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+class MatchMatrixTensorCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::MatchMatrixTensorParam;
+
+  void Run() override;
+
+  virtual ~MatchMatrixTensorCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/match_matrix_tensor_compute_test.cc b/lite/kernels/x86/match_matrix_tensor_compute_test.cc
new file mode 100644
index 0000000000..0c3f3ad509
--- /dev/null
+++ b/lite/kernels/x86/match_matrix_tensor_compute_test.cc
@@ -0,0 +1,116 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/match_matrix_tensor_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(match_matrix_tensor_x86, retrive_op) {
+  auto kernel =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "match_matrix_tensor");
+  ASSERT_FALSE(kernel.empty());
+  ASSERT_TRUE(kernel.front());
+}
+
+TEST(match_matrix_tensor_x86, init) {
+  MatchMatrixTensorCompute<float> mmtc;
+  ASSERT_EQ(mmtc.precision(), PRECISION(kFloat));
+  ASSERT_EQ(mmtc.target(), TARGET(kX86));
+}
+
+TEST(match_matrix_tensor_x86, run_test) {
+  int ix = 5, iy = 4, h = 2, dim_t = 2;
+  lite::Tensor x, w, y, out, tmp;
+  x.Resize({ix, h});
+  w.Resize({h, dim_t, h});
+  y.Resize({iy, h});
+  out.Resize({18, 1});
+  tmp.Resize({20, 1});
+
+  LoD x_lod{};
+  x_lod.push_back({0, 2, 5});
+  x.set_lod(x_lod);
+  LoD y_lod{};
+  y_lod.push_back({0, 3, 4});
+  y.set_lod(y_lod);
+
+  auto* x_data = x.mutable_data<float>();
+  for (int64_t i = 0; i < x.numel(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  auto* y_data = y.mutable_data<float>();
+  for (int64_t i = 0; i < y.numel(); i++) {
+    y_data[i] = static_cast<float>(i);
+  }
+  auto* w_data = w.mutable_data<float>();
+  for (int64_t i = 0; i < w.numel(); i++) {
+    w_data[i] = static_cast<float>(i);
+  }
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  MatchMatrixTensorCompute<float> mmtc;
+  mmtc.SetContext(std::move(ctx));
+
+  operators::MatchMatrixTensorParam param;
+  param.x = &x;
+  param.w = &w;
+  param.y = &y;
+  param.dim_t = dim_t;
+  param.out = &out;
+  param.tmp = &tmp;
+
+  mmtc.SetParam(param);
+  mmtc.Run();
+
+  std::vector<float> ref_results = {5,   23,  41,  17,  75,  133,
+                                    7,   33,  59,  27,  125, 223,
+                                    323, 455, 587, 557, 793, 1029};
+  auto* out_data = out.mutable_data<float>();
+  for (int i = 0; i < out.dims().production(); i++) {
+    EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
+    // LOG(INFO) << out_data[i];
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(match_matrix_tensor, kX86, kFloat, kNCHW, def);
diff --git a/lite/kernels/x86/search_grnn_compute.cc b/lite/kernels/x86/search_grnn_compute.cc
new file mode 100644
index 0000000000..95839ba71b
--- /dev/null
+++ b/lite/kernels/x86/search_grnn_compute.cc
@@ -0,0 +1,332 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/search_grnn_compute.h"
+#include <algorithm>
+#include <vector>
+#include "lite/backends/x86/math/blas.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+T sigmoid(T z) {
+  return 1 / (1 + std::exp(-z));
+}
+
+template <typename T>
+void CallGemm(const lite::x86::math::BlasT<lite::TargetType::kX86, T>& blas,
+              const CBLAS_TRANSPOSE TransA,
+              const CBLAS_TRANSPOSE TransB,
+              const int M,
+              const int N,
+              const int K,
+              const T alpha,
+              const T* A,
+              const T* B,
+              const T beta,
+              T* C) {
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
+}
+
+template <typename T>
+void SearchGrnnCompute<T>::PrepareLayout(const Tensor* input_blob) {
+  auto& param = this->Param<param_t>();
+  auto* _idx_sorted_by_width = param.idx_sorted_by_width;
+  auto* _layout_input = param.layout_input;
+  auto* _input = input_blob;
+
+  // usually the total length
+  int dim0 = _input->dims()[0];
+  // 1 if it is an id-only sequence
+  int dim1 = 1;
+  // if it is an embedding-like sequence (dim1 would be embedding_size)
+  if (_input->dims().size() > 1) {
+    dim1 = _input->dims()[1];
+  }
+
+  int batch = _input->lod()[0].size() - 1;
+  auto& offset = _input->lod()[0];
+
+  Tensor _width;
+  _width.Resize({batch});
+  _idx_sorted_by_width->Resize({batch});
+  int* width_data = _width.template mutable_data<int>();
+  int* idx_sorted_by_width_data =
+      _idx_sorted_by_width->template mutable_data<int>();
+  // sort sequences by width (descending) and find the largest width in the
+  // batch
+  for (int i = 0; i < batch; i++) {
+    width_data[i] = offset[i + 1] - offset[i];
+    idx_sorted_by_width_data[i] = i;
+  }
+  std::sort(idx_sorted_by_width_data,
+            idx_sorted_by_width_data + batch,
+            [&_width](int a, int b) {
+              return _width.template data<int>()[a] >
+                     _width.template data<int>()[b];
+            });
+  int max_width = width_data[idx_sorted_by_width_data[0]];
+
+  // start of reorganizing the input
+  std::vector<size_t> new_offset;
+  new_offset.resize(max_width + 1);
+
+  new_offset[0] = 0;
+  int j = batch - 1;
+  int last_width = 0;
+  int sub_row = 0;
+  int sub_col = 0;
+
+  for (int i = 1; i <= max_width;) {
+    for (int k = j; k >= 0; --k) {
+      if (width_data[idx_sorted_by_width_data[k]] > last_width) {
+        sub_row = width_data[idx_sorted_by_width_data[k]] - last_width;
+        sub_col = k + 1;
+
+        for (int s = 0; s < sub_row; s++) {
+          new_offset[i] = new_offset[i - 1] + sub_col;
+          i++;
+        }
+        // move on
+        last_width = width_data[idx_sorted_by_width_data[k]];
+        j = k - 1;
+        break;
+      }
+    }
+  }
+
+  // copy to the reorganized buffer
+  if (_input->dims().size() == 1) {
+    // _layout_input.reshape_batch_sequence({dim0}, new_offset);
+    LOG(FATAL) << "_input->dims().size() = 1, error.";
+  } else {
+    // _layout_input.reshape_batch_sequence({dim0, dim1}, new_offset);
+    LoD new_lod;
+    new_lod.push_back(new_offset);
+    _layout_input->set_lod(new_lod);
+    _layout_input->Resize({dim0, dim1});
+  }
+
+  auto* new_emb = _layout_input->template mutable_data<T>();
+  for (int i = 0; i < max_width; i++) {
+    int w = new_offset[i + 1] - new_offset[i];
+    auto* emb_start = new_emb + dim1 * new_offset[i];
+    for (int j = 0; j < w; ++j) {
+      memcpy(emb_start + dim1 * j,
+             _input->template data<T>() +
+                 dim1 * offset[idx_sorted_by_width_data[j]] + dim1 * i,
+             dim1 * sizeof(T));
+    }
+  }
+}
+
+template <typename T>
+void SearchGrnnCompute<T>::CopyBack(T* from, T* to, int step) {
+  auto& param = this->Param<param_t>();
+  auto* _input = param.x;
+  auto* _layout_input = param.layout_input;
+  auto* _idx_sorted_by_width = param.idx_sorted_by_width;
+
+  const auto& offset = _input->lod()[0];
+  const auto& new_offset = _layout_input->lod()[0];
+  const auto* idx_sorted_by_width_data =
+      _idx_sorted_by_width->template data<int>();
+  for (size_t i = 0; i < _layout_input->lod()[0].size() - 1; ++i) {
+    int w = new_offset[i + 1] - new_offset[i];
+    for (int j = 0; j < w; j++) {
+      memcpy(to + step * (offset[idx_sorted_by_width_data[j]] + i),
+             from + (new_offset[i] + j) * step,
+             step * sizeof(T));
+    }
+  }
+}
+
+template <typename T>
+void SearchGrnnCompute<T>::Run() {
+  auto& context = ctx_->As<X86Context>();
+  auto& param = this->Param<param_t>();
+  auto* bottom = param.x;
+  auto* wi = param.wi;
+  auto* wh = param.wh;
+  auto* top = param.out;
+  auto* _buffer = param.tmp_buffer;
+  int _cap_h = param.num_hidden;
+  int _cap_e = param.num_input;
+
+  int _cap_l = bottom->dims()[0];
+  int batch = bottom->lod()[0].size() - 1;
+
+  const auto& offset = bottom->lod()[0];
+  LoD top_lod;
+  top_lod.push_back(offset);
+  top->set_lod(top_lod);
+  std::vector<int64_t> top_dims_vec{_cap_l, _cap_h};
+  top->Resize(top_dims_vec);
+  auto* top_hidden = top->template mutable_data<T>();
+
+  const auto* dense_e2h = wi->template data<T>();
+  const auto* dense_h2h = wh->template data<T>();
+
+  const auto* e2h = dense_e2h;
+  const auto* e2hr = dense_e2h + 1 * _cap_e * _cap_h;
+  const auto* e2hz = dense_e2h + 2 * _cap_e * _cap_h;
+  const auto* h2h = dense_h2h;
+  const auto* h2hr = dense_h2h + 1 * _cap_h * _cap_h;
+  const auto* h2hz = dense_h2h + 2 * _cap_h * _cap_h;
+
+  PrepareLayout(bottom);
+
+  auto* _layout_input = param.layout_input;
+  auto* new_emb = _layout_input->template mutable_data<T>();
+  const auto& new_offset = _layout_input->lod()[0];
+  int max_width = _layout_input->lod()[0].size() - 1;
+
+  // this buffer holds bookkeeping info that will be used in backprop; the
+  // buffer is also needed in bp, so make it larger than forward alone needs
+  _buffer->Resize({20, _cap_l, _cap_h});
+  auto* buffer_data = _buffer->template mutable_data<T>();
+  auto* w_x_e = buffer_data + 0 * _cap_l * _cap_h;
+  auto* wr_x_e = buffer_data + 1 * _cap_l * _cap_h;
+  auto* wz_x_e = buffer_data + 2 * _cap_l * _cap_h;
+  auto* u_x_h = buffer_data + 3 * _cap_l * _cap_h;
+  auto* ur_x_h = buffer_data + 4 * _cap_l * _cap_h;
+  auto* uz_x_h = buffer_data + 5 * _cap_l * _cap_h;
+  auto* r = buffer_data + 6 * _cap_l * _cap_h;
+  auto* z = buffer_data + 7 * _cap_l * _cap_h;
+  auto* tilde = buffer_data + 8 * _cap_l * _cap_h;
+  // the internal hidden state
+  auto* hidden = buffer_data + 19 * _cap_l * _cap_h;
+
+  auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
+  CallGemm(blas, CblasNoTrans, CblasTrans, _cap_l, _cap_h, _cap_e,
+           static_cast<T>(1.0), new_emb, e2h, static_cast<T>(0.0), w_x_e);
+  CallGemm(blas, CblasNoTrans, CblasTrans, _cap_l, _cap_h, _cap_e,
+           static_cast<T>(1.0), new_emb, e2hr, static_cast<T>(0.0), wr_x_e);
+  CallGemm(blas, CblasNoTrans, CblasTrans, _cap_l, _cap_h, _cap_e,
+           static_cast<T>(1.0), new_emb, e2hz, static_cast<T>(0.0), wz_x_e);
+
+  // precompute hidden0
+  for (int i = 0; i < batch * _cap_h; i++) {
+    tilde[i] = std::tanh(w_x_e[i]);
+    z[i] = sigmoid(wz_x_e[i]);
+    hidden[i] = (1. - z[i]) * tilde[i];
+  }
+
+  // recurrence
+  for (int i = 1; i < max_width; i++) {
+    int w_tm1 = new_offset[i] - new_offset[i - 1];
+    int w = new_offset[i + 1] - new_offset[i];
+
+    // project hidden i-1 onto the gates of step i
+    auto* htm1 = hidden + new_offset[i - 1] * _cap_h;
+
+    CallGemm(blas, CblasNoTrans, CblasTrans, w, _cap_h, _cap_h,
+             static_cast<T>(1.0), htm1, h2h, static_cast<T>(0.0),
+             u_x_h + new_offset[i] * _cap_h);
+    CallGemm(blas, CblasNoTrans, CblasTrans, w, _cap_h, _cap_h,
+             static_cast<T>(1.0), htm1, h2hr, static_cast<T>(0.0),
+             ur_x_h + new_offset[i] * _cap_h);
+    CallGemm(blas, CblasNoTrans, CblasTrans, w, _cap_h, _cap_h,
+             static_cast<T>(1.0), htm1, h2hz, static_cast<T>(0.0),
+             uz_x_h + new_offset[i] * _cap_h);
+
+    // compute the gates and the hidden state
+    for (size_t j = new_offset[i] * _cap_h; j < (new_offset[i] + w) * _cap_h;
+         j++) {
+      r[j] = sigmoid(wr_x_e[j] + ur_x_h[j]);
+      z[j] = sigmoid(wz_x_e[j] + uz_x_h[j]);
+      tilde[j] = std::tanh(w_x_e[j] + r[j] * u_x_h[j]);
+      hidden[j] = z[j] * hidden[j - _cap_h * w_tm1] + (1.0 - z[j]) * tilde[j];
+    }
+  }
+
+  CopyBack(hidden, top_hidden, _cap_h);
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(search_grnn,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::SearchGrnnCompute<float>,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Wi", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Wh", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("tmp_buffer", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("idx_sorted_by_width",
+                {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindOutput("layout_input", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/search_grnn_compute.h b/lite/kernels/x86/search_grnn_compute.h
new file mode 100644
index 0000000000..5082faf45b
--- /dev/null
+++ b/lite/kernels/x86/search_grnn_compute.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/operators/op_params.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+class SearchGrnnCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SearchGrnnParam;
+
+  void Run() override;
+
+  virtual ~SearchGrnnCompute() = default;
+
+ private:
+  void PrepareLayout(const Tensor* input);
+  void CopyBack(T* from, T* to, int step);
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/search_seq_depadding_compute.cc b/lite/kernels/x86/search_seq_depadding_compute.cc
new file mode 100644
index 0000000000..db1816fb48
--- /dev/null
+++ b/lite/kernels/x86/search_seq_depadding_compute.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/search_seq_depadding_compute.h"
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+void SearchSeqDepaddingCompute<T>::Run() {
+  auto& param = this->Param<param_t>();
+  auto* pad = param.pad;
+  auto* src = param.src;
+  auto* out = param.out;
+
+  const int pad_batch = pad->lod()[0].size() - 1;
+  const int src_batch = src->lod()[0].size() - 1;
+  if (pad_batch % src_batch != 0) {
+    LOG(FATAL) << "Mismatched batch size.";
+  }
+
+  const auto& pad_offset = pad->lod()[0];
+  const int pad_cap_e = pad->dims()[1];
+  const auto& src_offset = src->lod()[0];
+  const int src_cap_l = src->dims()[0];
+
+  LoD out_lod;
+  out_lod.push_back(src_offset);
+  out->set_lod(out_lod);
+  out->Resize({src_cap_l, pad_cap_e});
+
+  const auto* pad_data = pad->template data<T>();
+  auto* out_data = out->template mutable_data<T>();
+  for (int i = 0; i < src_batch; ++i) {
+    const int src_i_l = src_offset[i + 1] - src_offset[i];
+    const int pad_i_l = pad_offset[i + 1] - pad_offset[i];
+    if (pad_i_l < src_i_l) {
+      LOG(FATAL)
+          << "the length of the padded seq input is less than the source seq "
+             "input.";
+    }
+    memcpy(out_data + src_offset[i] * pad_cap_e,
+           pad_data + pad_offset[i] * pad_cap_e,
+           src_i_l * pad_cap_e * sizeof(T));
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    search_seq_depadding,
+    kX86,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::x86::SearchSeqDepaddingCompute<float>,
+    def)
+    .BindInput("Pad", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Src", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/search_seq_depadding_compute.h b/lite/kernels/x86/search_seq_depadding_compute.h
new file mode 100644
index 0000000000..e48fa92723
--- /dev/null
+++ b/lite/kernels/x86/search_seq_depadding_compute.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/operators/op_params.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+class SearchSeqDepaddingCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SearchSeqDepaddingParam;
+
+  void Run() override;
+
+  virtual ~SearchSeqDepaddingCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/search_seq_depadding_compute_test.cc b/lite/kernels/x86/search_seq_depadding_compute_test.cc
new file mode 100644
index 0000000000..0d978b35ed
--- /dev/null
+++ b/lite/kernels/x86/search_seq_depadding_compute_test.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/search_seq_depadding_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(search_seq_depadding_x86, retrive_op) {
+  auto kernel =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "search_seq_depadding");
+  ASSERT_FALSE(kernel.empty());
+  ASSERT_TRUE(kernel.front());
+}
+
+TEST(search_seq_depadding_x86, init) {
+  SearchSeqDepaddingCompute<float> ssdc;
+  ASSERT_EQ(ssdc.precision(), PRECISION(kFloat));
+  ASSERT_EQ(ssdc.target(), TARGET(kX86));
+}
+
+TEST(search_seq_depadding_x86, run_test) {
+  lite::Tensor pad, src, out;
+  pad.Resize({2 * 3, 4});
+  src.Resize({3, 1});
+  out.Resize({3, 4});
+  LoD pad_lod{};
+  pad_lod.push_back({0, 4, 6});
+  pad.set_lod(pad_lod);
+  LoD src_lod{};
+  src_lod.push_back({0, 2, 3});
+  src.set_lod(src_lod);
+
+  auto* pad_data = pad.mutable_data<float>();
+  for (int64_t i = 0; i < pad.dims().production(); i++) {
+    pad_data[i] = static_cast<float>(i);
+  }
+  SearchSeqDepaddingCompute<float> ssdc;
+  operators::SearchSeqDepaddingParam param;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  ssdc.SetContext(std::move(ctx));
+
+  param.pad = &pad;
+  param.src = &src;
+  param.out = &out;
+
+  ssdc.SetParam(param);
+  ssdc.Run();
+
+  std::vector<float> ref_results = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19};
+  auto* out_data = out.mutable_data<float>();
+  for (int i = 0; i < out.dims().production(); i++) {
+    EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(search_seq_depadding, kX86, kFloat, kNCHW, def);
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 06961d6103..39eac353af 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -80,6 +80,9 @@ add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wis
 add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
 add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
 add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
+add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
+add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS})
+add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS})
 add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
 add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/match_matrix_tensor_op.cc b/lite/operators/match_matrix_tensor_op.cc
new file mode 100644
index 0000000000..8efc8866d9
--- /dev/null
+++ b/lite/operators/match_matrix_tensor_op.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/match_matrix_tensor_op.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool MatchMatrixTensorOpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.y);
+  CHECK_OR_FALSE(param_.w);
+  CHECK_OR_FALSE(param_.out);
+  CHECK_OR_FALSE(param_.tmp);
+
+  DDim x_dims = param_.x->dims();
+  DDim y_dims = param_.y->dims();
+  DDim w_dims = param_.w->dims();
+  int dim_t = param_.dim_t;
+
+  CHECK_OR_FALSE(x_dims.size() == 2);
+  CHECK_OR_FALSE(y_dims.size() == 2);
+  CHECK_OR_FALSE(w_dims.size() == 3);
+  CHECK_OR_FALSE(x_dims[1] == w_dims[0] && y_dims[1] == w_dims[2] &&
+                 w_dims[1] == dim_t);
+
+  return true;
+}
+
+bool MatchMatrixTensorOpLite::InferShape() const {
+  const Tensor* x = param_.x;
+  const Tensor* y = param_.y;
+  DDim x_dims = param_.x->dims();
+  DDim y_dims = param_.y->dims();
+  DDim w_dims = param_.w->dims();
+  int dim_t = param_.dim_t;
+
+  const auto& x_lod = x->lod();
+  CHECK_OR_FALSE(!x_lod.empty());
+  const auto& x_lod_0 = x_lod[0];
+  CHECK_OR_FALSE(x_lod_0.size() >= 2);
+  CHECK_OR_FALSE(x_dims[0] == x_lod_0.back());
+
+  const auto& y_lod = y->lod();
+  CHECK_OR_FALSE(!y_lod.empty());
+  const auto& y_lod_0 = y_lod[0];
+  CHECK_OR_FALSE(y_lod_0.size() >= 2);
+  CHECK_OR_FALSE(y_dims[0] == y_lod_0.back());
+
+  CHECK_OR_FALSE(x_lod_0.size() == y_lod_0.size());
+
+  int out_dim_0 = 0;
+  for (size_t i = 1; i < x_lod_0.size(); i++) {
+    int x_len = x_lod_0[i] - x_lod_0[i - 1];
+    int y_len = y_lod_0[i] - y_lod_0[i - 1];
+    out_dim_0 += (x_len * y_len);
+  }
+  out_dim_0 *= dim_t;
+  int tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];
+
+  param_.out->Resize({out_dim_0, 1});
+  param_.tmp->Resize({tmp_dim_0, 1});
+  return true;
+}
+
+bool MatchMatrixTensorOpLite::AttachImpl(const cpp::OpDesc& op_desc,
+                                         lite::Scope* scope) {
+  auto x = op_desc.Input("X").front();
+  auto w = op_desc.Input("W").front();
+  auto y = op_desc.Input("Y").front();
+  auto out = op_desc.Output("Out").front();
+  auto tmp = op_desc.Output("Tmp").front();
+
+  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.w = scope->FindVar(w)->GetMutable<lite::Tensor>();
+  param_.y = scope->FindVar(y)->GetMutable<lite::Tensor>();
+  param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+  param_.tmp = scope->FindVar(tmp)->GetMutable<lite::Tensor>();
+
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(match_matrix_tensor,
+                 paddle::lite::operators::MatchMatrixTensorOpLite);
diff --git a/lite/operators/match_matrix_tensor_op.h b/lite/operators/match_matrix_tensor_op.h
new file mode 100644
index 0000000000..404183ea5b
--- /dev/null
+++ b/lite/operators/match_matrix_tensor_op.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class MatchMatrixTensorOpLite : public OpLite {
+ public:
+  MatchMatrixTensorOpLite() {}
+
+  explicit MatchMatrixTensorOpLite(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "match_matrix_tensor"; }
+
+ private:
+  mutable MatchMatrixTensorParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index ac32aed7b1..360bcbecaf 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -951,6 +951,38 @@ struct AssignValueParam {
   lite::Tensor* Out{};
 };
 
+/// --------------------- match_matrix_tensor operators --------------------
+struct MatchMatrixTensorParam {
+  const lite::Tensor* x{};
+  const lite::Tensor* y{};
+  const lite::Tensor* w{};
+  lite::Tensor* out{};
+  lite::Tensor* tmp{};
+
+  int dim_t;
+};
+
+/// --------------------- search_seq_depadding operators --------------------
+struct SearchSeqDepaddingParam {
+  const lite::Tensor* pad{};
+  const lite::Tensor* src{};
+  lite::Tensor* out{};
+};
+
+/// --------------------- search_grnn operators --------------------
+struct SearchGrnnParam {
+  const lite::Tensor* x{};
+  const lite::Tensor* wi{};
+  const lite::Tensor* wh{};
+  int num_input;
+  int num_hidden;
+
+  lite::Tensor* out{};
+  lite::Tensor* tmp_buffer{};
+  lite::Tensor* idx_sorted_by_width{};
+  lite::Tensor* layout_input{};
+};
+
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/operators/search_grnn_op.cc b/lite/operators/search_grnn_op.cc
new file mode 100644
index 0000000000..b56ae820bf
--- /dev/null
+++ b/lite/operators/search_grnn_op.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/search_grnn_op.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SearchGrnnOpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.wi);
+  CHECK_OR_FALSE(param_.wh);
+  CHECK_OR_FALSE(param_.out);
+  CHECK_OR_FALSE(param_.tmp_buffer);
+  CHECK_OR_FALSE(param_.idx_sorted_by_width);
+  CHECK_OR_FALSE(param_.layout_input);
+
+  int _cap_h = param_.num_hidden;
+  int _cap_e = param_.num_input;
+
+  const auto& x_dims = param_.x->dims();
+  CHECK_OR_FALSE(x_dims.size() == 2);
+  CHECK_OR_FALSE(x_dims[1] == _cap_e);
+
+  const auto& wi_dims = param_.wi->dims();
+  CHECK_OR_FALSE(wi_dims.size() == 3);
+  CHECK_OR_FALSE(wi_dims[0] == 3);
+  CHECK_OR_FALSE(wi_dims[1] == _cap_h);
+  CHECK_OR_FALSE(wi_dims[2] == _cap_e);
+
+  const auto& wh_dims = param_.wh->dims();
+  CHECK_OR_FALSE(wh_dims.size() == 3);
+  CHECK_OR_FALSE(wh_dims[0] == 3);
+  CHECK_OR_FALSE(wh_dims[1] == _cap_h);
+  CHECK_OR_FALSE(wh_dims[2] == _cap_h);
+
+  return true;
+}
+
+bool SearchGrnnOpLite::InferShape() const {
+  const auto& x_dims = param_.x->dims();
+  const auto& x_lod = param_.x->lod();
+  CHECK_OR_FALSE(!x_lod.empty());
+  CHECK_OR_FALSE(x_dims[0] == x_lod[0].back());
+  param_.out->set_lod(x_lod);
+
+  return true;
+}
+
+bool SearchGrnnOpLite::AttachImpl(const cpp::OpDesc& op_desc,
+                                  lite::Scope* scope) {
+  auto x = op_desc.Input("X").front();
+  auto wi = op_desc.Input("Wi").front();
+  auto wh = op_desc.Input("Wh").front();
+  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.wi = scope->FindVar(wi)->GetMutable<lite::Tensor>();
+  param_.wh = scope->FindVar(wh)->GetMutable<lite::Tensor>();
+
+  param_.num_input = op_desc.GetAttr<int>("num_input");
+  param_.num_hidden = op_desc.GetAttr<int>("num_hidden");
+
+  auto out = op_desc.Output("Out").front();
+  auto tmp_buffer = op_desc.Output("tmp_buffer").front();
+  auto idx_sorted_by_width = op_desc.Output("idx_sorted_by_width").front();
+  auto layout_input = op_desc.Output("layout_input").front();
+  param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+  param_.tmp_buffer = scope->FindVar(tmp_buffer)->GetMutable<lite::Tensor>();
+  param_.idx_sorted_by_width =
+      scope->FindVar(idx_sorted_by_width)->GetMutable<lite::Tensor>();
+  param_.layout_input =
+      scope->FindVar(layout_input)->GetMutable<lite::Tensor>();
+
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(search_grnn, paddle::lite::operators::SearchGrnnOpLite);
diff --git a/lite/operators/search_grnn_op.h b/lite/operators/search_grnn_op.h
new file mode 100644
index 0000000000..670af8a6c9
--- /dev/null
+++ b/lite/operators/search_grnn_op.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SearchGrnnOpLite : public OpLite {
+ public:
+  SearchGrnnOpLite() {}
+
+  explicit SearchGrnnOpLite(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "search_grnn"; }
+
+ private:
+  mutable SearchGrnnParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/search_seq_depadding_op.cc b/lite/operators/search_seq_depadding_op.cc
new file mode 100644
index 0000000000..12d5123e05
--- /dev/null
+++ b/lite/operators/search_seq_depadding_op.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/search_seq_depadding_op.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SearchSeqDepaddingOpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.pad);
+  CHECK_OR_FALSE(param_.src);
+  CHECK_OR_FALSE(param_.out);
+
+  DDim pad_dims = param_.pad->dims();
+  DDim src_dims = param_.src->dims();
+  CHECK_OR_FALSE(pad_dims.size() == 2);
+  CHECK_OR_FALSE(src_dims.size() == 2);
+
+  const auto& pad_lod = param_.pad->lod();
+  CHECK_OR_FALSE(!pad_lod.empty());
+  const auto& pad_lod_0 = pad_lod[0];
+  CHECK_OR_FALSE(pad_lod_0.size() >= 2);
+  CHECK_OR_FALSE(pad_dims[0] == pad_lod_0.back());
+
+  const auto& src_lod = param_.src->lod();
+  CHECK_OR_FALSE(!src_lod.empty());
+  const auto& src_lod_0 = src_lod[0];
+  CHECK_OR_FALSE(src_lod_0.size() >= 2);
+  CHECK_OR_FALSE(src_dims[0] == src_lod_0.back());
+  return true;
+}
+
+bool SearchSeqDepaddingOpLite::InferShape() const {
+  DDim pad_dims = param_.pad->dims();
+  param_.out->Resize({-1, pad_dims[1]});
+  return true;
+}
+
+bool SearchSeqDepaddingOpLite::AttachImpl(const cpp::OpDesc& op_desc,
+                                          lite::Scope* scope) {
+  auto pad = op_desc.Input("Pad").front();
+  auto src = op_desc.Input("Src").front();
+  auto out = op_desc.Output("Out").front();
+
+  param_.pad = scope->FindVar(pad)->GetMutable<lite::Tensor>();
+  param_.src = scope->FindVar(src)->GetMutable<lite::Tensor>();
+  param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(search_seq_depadding,
+                 paddle::lite::operators::SearchSeqDepaddingOpLite);
diff --git a/lite/operators/search_seq_depadding_op.h b/lite/operators/search_seq_depadding_op.h
new file mode 100644
index 0000000000..445d9e0f3b
--- /dev/null
+++ b/lite/operators/search_seq_depadding_op.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SearchSeqDepaddingOpLite : public OpLite {
+ public:
+  SearchSeqDepaddingOpLite() {}
+
+  explicit SearchSeqDepaddingOpLite(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "search_seq_depadding"; }
+
+ private:
+  mutable SearchSeqDepaddingParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
-- 
GitLab
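
Note on the depadding semantics exercised by both tests: for each sequence i, the op copies the first len(src_i) rows of the padded input Pad into a dense output, so Out takes src's LoD and Pad's embedding width. The x86 kernel does this with one memcpy per sequence; the CUDA kernel materializes the same row mapping as seq_id_map and gathers rows in parallel. A minimal host-side sketch of that behavior follows (the function and parameter names are illustrative, not part of the patch):

#include <cstring>
#include <vector>

// Reference depadding: keep only the valid rows of each padded sequence.
// pad_offset / src_offset follow LoD level-0 conventions (cumulative, {0, ...}).
void depad_ref(const float* pad,
               const std::vector<int>& pad_offset,
               const std::vector<int>& src_offset,
               int emb_size,
               float* out) {
  int batch = static_cast<int>(src_offset.size()) - 1;
  for (int i = 0; i < batch; ++i) {
    int valid_rows = src_offset[i + 1] - src_offset[i];
    std::memcpy(out + src_offset[i] * emb_size,
                pad + pad_offset[i] * emb_size,
                valid_rows * emb_size * sizeof(float));
  }
}

With the LoDs used in the tests (pad {0, 4, 6}, src {0, 2, 3}, emb_size 4), this keeps pad rows 0, 1, and 4, i.e. the values {0..7, 16..19}, which is exactly ref_results.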
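The subtler piece of search_grnn is PrepareLayout: sequences are sorted by length in descending order and their words regrouped time-step-major, so at every time step the still-active sequences form one contiguous block and each gate projection becomes a single GEMM. A small sketch of the offset computation under that scheme (a hypothetical helper, not part of the patch):

#include <algorithm>
#include <functional>
#include <vector>

// new_offset[t + 1] - new_offset[t] is the number of sequences still active
// at time step t, so each step's inputs and hidden states are dense blocks.
std::vector<int> step_offsets(std::vector<int> widths) {
  std::sort(widths.begin(), widths.end(), std::greater<int>());
  int max_width = widths.empty() ? 0 : widths.front();
  std::vector<int> offset(max_width + 1, 0);
  for (int t = 0; t < max_width; ++t) {
    int active = 0;  // sequences whose length exceeds time step t
    while (active < static_cast<int>(widths.size()) && widths[active] > t) {
      ++active;
    }
    offset[t + 1] = offset[t] + active;
  }
  return offset;
}

For widths {3, 2} this yields {0, 2, 4, 5}: two sequences are active at steps 0 and 1, one at step 2, matching the new_offset that PrepareLayout builds before the recurrence.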