diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index 423fa5355288fcfa0f2cc50f3930cb6e5f013b8d..b774d63b3c513bbe2912b31195022925a574e21a 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
 add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
+add_kernel(search_group_padding_compute_x86 X86 extra SRCS search_group_padding_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
 add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
@@ -71,6 +72,7 @@ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS ba
 lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
 lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
 lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
+lite_cc_test(test_search_group_padding_compute_x86 SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_x86)
 lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86)
 lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86)
 lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
diff --git a/lite/kernels/x86/search_group_padding_compute.cc b/lite/kernels/x86/search_group_padding_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d1847ac9dbafc533b8720ab65e6fa1915d5a136e
--- /dev/null
+++ b/lite/kernels/x86/search_group_padding_compute.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/search_group_padding_compute.h"
+
+REGISTER_LITE_KERNEL(
+    search_group_padding,
+    kX86,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::x86::SearchGroupPaddingCompute<float>,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out_emb_padding", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out_new", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out_padding", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/search_group_padding_compute.h b/lite/kernels/x86/search_group_padding_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..17244d15d9124d9d61d1f4fdef4f12590958c0be
--- /dev/null
+++ b/lite/kernels/x86/search_group_padding_compute.h
@@ -0,0 +1,112 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+// `search_group_padding` pads every LoD sequence of X (shape {dim0, dim1})
+// up to the longest sequence length (max_seq) in the batch:
+//   Out_emb_padding: {batch * max_seq, dim1}; padded rows are zero-filled.
+//   Out_new:         {dim0, 1}, zero-filled, carries the original LoD.
+//   Out_padding:     {batch * max_seq, 1}; 0 on real rows, pad_id elsewhere.
+template <typename T>
+class SearchGroupPaddingCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SearchGroupPaddingParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+
+    auto* bottom0 = param.x;
+    auto* top0 = param.out_emb_padding;
+    auto* top1 = param.out_new;
+    auto* top2 = param.out_padding;
+
+    int _pad_id = param.pad_id;
+
+    int batch = bottom0->lod()[0].size() - 1;
+    int dim0 = bottom0->dims()[0];
+    int dim1 = bottom0->dims()[1];
+
+    const auto offset = bottom0->lod()[0];
+    // Longest sequence in the batch; every sequence is padded up to it.
+    int max_seq = 0;
+    for (int i = 0; i < batch; ++i) {
+      if (offset[i + 1] - offset[i] > max_seq) {
+        max_seq = offset[i + 1] - offset[i];
+      }
+    }
+
+    std::vector<uint64_t> new_offset;
+    new_offset.resize(batch + 1);
+    for (int i = 0; i < batch + 1; ++i) {
+      new_offset[i] = i * max_seq;
+    }
+
+    // for padding data
+    lite::LoD top0_lod;
+    top0_lod.push_back(new_offset);
+    top0->set_lod(top0_lod);
+    top0->Resize({batch * max_seq, dim1});
+    // for origin input id
+    // already set by ShareLoD in InferShape
+    lite::LoD top1_lod;
+    top1_lod.push_back(offset);
+    top1->set_lod(top1_lod);
+    top1->Resize({dim0, 1});
+    memset(top1->mutable_data<T>(),
+           0,
+           top1->dims()[0] * top1->dims()[1] * sizeof(T));
+    // for padding input id
+    lite::LoD top2_lod;
+    top2_lod.push_back(new_offset);
+    top2->set_lod(top2_lod);
+    top2->Resize({batch * max_seq, 1});
+    // copy data
+    const auto* bottom_data = bottom0->data<T>();
+    auto* top_data = top0->mutable_data<T>();
+    auto* top_padding_input_data = top2->mutable_data<T>();
+    for (int i = 0; i < batch; i++) {
+      const int copy_step = offset[i + 1] - offset[i];
+      const int start = i * max_seq;
+      // Real rows of sequence i, then a zero-filled tail up to max_seq rows.
+      memcpy(top_data + start * dim1,
+             bottom_data + offset[i] * dim1,
+             copy_step * dim1 * sizeof(T));
+      memset(top_data + (start + copy_step) * dim1,
+             0,
+             (max_seq - copy_step) * dim1 * sizeof(T));
+      // for padding input id
+      memset(top_padding_input_data + start, 0, copy_step * sizeof(T));
+      for (int j = start + copy_step; j < start + max_seq; j++) {
+        top_padding_input_data[j] = static_cast<T>(_pad_id);
+      }
+    }
+  }
+
+  virtual ~SearchGroupPaddingCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/search_group_padding_compute_test.cc b/lite/kernels/x86/search_group_padding_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f4c36c2a63488a6bb902a2b8b4ad81fa32b37672
--- /dev/null
+++ b/lite/kernels/x86/search_group_padding_compute_test.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/search_group_padding_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(search_group_padding_x86, retrieve_op) {
+  auto search_group_padding =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "search_group_padding");
+  ASSERT_FALSE(search_group_padding.empty());
+  ASSERT_TRUE(search_group_padding.front());
+}
+
+TEST(search_group_padding_x86, init) {
+  SearchGroupPaddingCompute<float> search_group_padding;
+  ASSERT_EQ(search_group_padding.precision(), PRECISION(kFloat));
+  ASSERT_EQ(search_group_padding.target(), TARGET(kX86));
+}
+
+TEST(search_group_padding_x86, run_test) {
+  lite::Tensor x, out_emb_padding, out_new, out_padding;
+  x.Resize({2, 3});
+  out_emb_padding.Resize({-1, 3});
+  out_new.Resize({2, 1});
+  out_padding.Resize({-1, 1});
+  LoD x_lod{};
+  x_lod.push_back({0, 1});  // one sequence covering row 0; max_seq == 1
+  x.set_lod(x_lod);
+
+  auto* x_data = x.mutable_data<float>();
+  for (int64_t i = 0; i < x.dims().production(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  SearchGroupPaddingCompute<float> sgp_kernel;
+  operators::SearchGroupPaddingParam param;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  sgp_kernel.SetContext(std::move(ctx));
+
+  param.x = &x;
+  param.out_emb_padding = &out_emb_padding;
+  param.out_new = &out_new;
+  param.out_padding = &out_padding;
+
+  sgp_kernel.SetParam(param);
+  sgp_kernel.Run();
+
+  std::vector<float> out_emb_padding_ref = {0, 1, 2};
+  std::vector<float> out_new_ref = {0, 0};
+  std::vector<float> out_padding_ref = {0};
+  auto* out_emb_padding_data = out_emb_padding.mutable_data<float>();
+  auto* out_new_data = out_new.mutable_data<float>();
+  auto* out_padding_data = out_padding.mutable_data<float>();
+  for (int i = 0; i < out_emb_padding.dims().production(); i++) {
+    EXPECT_NEAR(out_emb_padding_data[i], out_emb_padding_ref[i], 1e-5);
+  }
+  for (int i = 0; i < out_new.dims().production(); i++) {
+    EXPECT_NEAR(out_new_data[i], out_new_ref[i], 1e-5);
+  }
+  for (int i = 0; i < out_padding.dims().production(); i++) {
+    EXPECT_NEAR(out_padding_data[i], out_padding_ref[i], 1e-5);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(search_group_padding, kX86, kFloat, kNCHW, def);
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index bc48f853c6f7bc523faf431089ceebe0b1044301..19d6871921d10d26b0a1c001abeda19308fc869e 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -7,6 +7,7 @@ add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS})
 add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS})
 add_operator(assign_op extra SRCS assign_op.cc DEPS ${op_DEPS})
 add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS})
+add_operator(search_group_padding extra SRCS search_group_padding_op.cc DEPS ${op_DEPS})
 add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS})
 add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS})
 add_operator(scale_op basic SRCS scale_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index a3f9e45048ad0a01c9429c163bf38da876dd5f42..f4373ad8d851270c8e83690113b8d40d0b5e0433 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -740,6 +740,14 @@ struct SequencePoolParam {
 #endif
 };
 
+struct SearchGroupPaddingParam {
+  lite::Tensor* x{};
+  lite::Tensor* out_emb_padding{};
+  lite::Tensor* out_new{};
+  lite::Tensor* out_padding{};
+  int pad_id{0};  // brace-init like the other members; avoids reading garbage
+};
+
 struct SequenceReshapeParam {
   lite::Tensor* x{};
   lite::Tensor* output{};
diff --git a/lite/operators/search_group_padding_op.cc b/lite/operators/search_group_padding_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2556468100bc75add8ab75b422371602283157a8
--- /dev/null
+++ b/lite/operators/search_group_padding_op.cc
@@ -0,0 +1,67 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/search_group_padding_op.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SearchGroupPaddingOp::CheckShape() const {
+  CHECK_EQ(param_.x->dims().size(), 2) << "The rank of X(Input) should be 2.";
+  CHECK_EQ(param_.x->lod().empty(), false)
+      << "Input Tensor of X does not contain LoD information.";
+  CHECK_GE(param_.x->lod()[0].size(), 2)
+      << "The Input(X)'s lod info is corrupted.";
+  CHECK_EQ(param_.x->dims()[0], static_cast<int64_t>(param_.x->lod()[0].back()))
+      << "The Input(X)'s lod info mismatches the actual tensor shape.";
+
+  return true;
+}
+
+bool SearchGroupPaddingOp::InferShape() const {
+  std::vector<int64_t> x_dims = param_.x->dims().Vectorize();
+
+  param_.out_emb_padding->Resize({-1, x_dims[1]});
+  param_.out_new->Resize({x_dims[0], 1});
+  param_.out_padding->Resize({-1, 1});
+  return true;
+}
+
+bool SearchGroupPaddingOp::AttachImpl(const cpp::OpDesc &op_desc,
+                                      lite::Scope *scope) {
+  auto x = op_desc.Input("X").front();
+  // Out_* are op outputs: look them up via Output(), not Input().
+  auto out_emb_padding = op_desc.Output("Out_emb_padding").front();
+  auto out_new = op_desc.Output("Out_new").front();
+  auto out_padding = op_desc.Output("Out_padding").front();
+
+  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.out_emb_padding =
+      scope->FindVar(out_emb_padding)->GetMutable<lite::Tensor>();
+  param_.out_new = scope->FindVar(out_new)->GetMutable<lite::Tensor>();
+  param_.out_padding =
+      scope->FindVar(out_padding)->GetMutable<lite::Tensor>();
+  param_.pad_id = op_desc.GetAttr<int>("pad_id");
+
+  CHECK(param_.out_emb_padding)
+      << "Output(Out_emb_padding) of SearchGroupPadding Op should not be null.";
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(search_group_padding,
+                 paddle::lite::operators::SearchGroupPaddingOp);
diff --git a/lite/operators/search_group_padding_op.h b/lite/operators/search_group_padding_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..a8e96c9697b5f7de70349efa1f8b378a47c3823c
--- /dev/null
+++ b/lite/operators/search_group_padding_op.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SearchGroupPaddingOp : public OpLite {
+ public:
+  SearchGroupPaddingOp() {}
+  explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {}
+  bool CheckShape() const override;
+  bool InferShape() const override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "search_group_padding"; }
+
+ private:
+  mutable SearchGroupPaddingParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle