From 80fb550c5c4d94204b33a70c518c0879b93569d2 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 16 Mar 2020 09:45:56 +0800 Subject: [PATCH] Add crf_decoding op, test=develop (#3167) --- lite/kernels/host/CMakeLists.txt | 1 + lite/kernels/host/crf_decoding_compute.cc | 116 ++++++++++++++++++ lite/kernels/host/crf_decoding_compute.h | 95 ++++++++++++++ lite/operators/CMakeLists.txt | 1 + lite/operators/crf_decoding_op.cc | 100 +++++++++++++++ lite/operators/crf_decoding_op.h | 48 ++++++++ lite/operators/op_params.h | 8 ++ lite/tests/kernels/CMakeLists.txt | 1 + .../kernels/crf_decoding_compute_test.cc | 112 +++++++++++++++++ 9 files changed, 482 insertions(+) create mode 100644 lite/kernels/host/crf_decoding_compute.cc create mode 100644 lite/kernels/host/crf_decoding_compute.h create mode 100644 lite/operators/crf_decoding_op.cc create mode 100644 lite/operators/crf_decoding_op.h create mode 100644 lite/tests/kernels/crf_decoding_compute_test.cc diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 428cc213ce..f337e518ab 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps}) #lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) #lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any) diff --git a/lite/kernels/host/crf_decoding_compute.cc b/lite/kernels/host/crf_decoding_compute.cc new file mode 100644 index 0000000000..09bb41de63 --- /dev/null +++ b/lite/kernels/host/crf_decoding_compute.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/host/crf_decoding_compute.h" +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +void CrfDecodingCompute::Run() { + auto& param = Param(); + auto* emission_weights = param.emission; + auto* transition_weights = param.transition; + auto* label = param.label; + auto* decoded_path = param.viterbi_path; + + int64_t* path = decoded_path->mutable_data(); + std::fill(path, path + decoded_path->numel(), 0); + + if (param.length != nullptr) { + auto* length = param.length; + int64_t seq_num = length->numel(); + const int64_t* length_data = length->data(); + auto in_dims = emission_weights->dims(); + + Tensor emission_weights_tmp = *emission_weights; + emission_weights_tmp.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + decoded_path->Resize({in_dims[0] * in_dims[1], 1}); + for (int64_t i = 0; i < seq_num; ++i) { + if (length_data[i] == 0) continue; + int64_t start_pos = i * in_dims[1]; + int64_t end_pos = start_pos + length_data[i]; + Tensor decoded_path_one_seq = + decoded_path->Slice(start_pos, end_pos); + Decode(emission_weights_tmp.Slice(start_pos, end_pos), + *transition_weights, + &decoded_path_one_seq); + } + if (label != nullptr) { + const int64_t* label_value = label->data(); + for (int64_t i = 0; i < seq_num; ++i) { + for (int64_t j = 0; j < in_dims[1]; ++j) { + int64_t start_pos = i * in_dims[1]; + if (j < length_data[i]) { + path[start_pos + j] = + label_value[start_pos + j] == path[start_pos + j] ? 1 : 0; + } else { + path[start_pos + j] = 0; + } + } + } + } + } else { + auto lod = emission_weights->lod(); + CHECK_EQ(lod.size(), 1UL); + CHECK_GT(lod.size(), 0); + const size_t level = 0; + const size_t seq_num = lod[level].size() - 1; + + for (size_t i = 0; i < seq_num; ++i) { + if (lod[level][i] == lod[level][i + 1]) continue; + int64_t start_pos = static_cast(lod[level][i]); + int64_t end_pos = static_cast(lod[level][i + 1]); + Tensor decoded_path_one_seq = + decoded_path->Slice(start_pos, end_pos); + Decode(emission_weights->Slice(start_pos, end_pos), + *transition_weights, + &decoded_path_one_seq); + } + if (label != nullptr) { + auto label_lod = label->lod(); + CHECK_EQ(label_lod.size(), 1); + const int64_t* label_value = label->data(); + int64_t numel = label->numel(); + for (int64_t i = 0; i < numel; ++i) { + path[i] = label_value[i] == path[i] ? 1 : 0; + } + } + } +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(crf_decoding, + kHost, + kFloat, + kNCHW, + paddle::lite::kernels::host::CrfDecodingCompute, + def) + .BindInput("Emission", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Transition", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Label", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Length", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("ViterbiPath", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/host/crf_decoding_compute.h b/lite/kernels/host/crf_decoding_compute.h new file mode 100644 index 0000000000..dd0cb85000 --- /dev/null +++ b/lite/kernels/host/crf_decoding_compute.h @@ -0,0 +1,95 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +template +void Decode(const Tensor& emission_weights, + const Tensor& transition_weights, + Tensor* decoded_path) { + auto emission_dims = emission_weights.dims(); + const int64_t seq_len = emission_dims[0]; + const int64_t tag_num = emission_dims[1]; + const T* x = emission_weights.data(); + const T* w = transition_weights.data(); + int64_t* path = decoded_path->mutable_data(); + + // alpha is a memo table. An element alpha(k, v) records the score of the + // best sequence of tags from position 1 to position k with v being the end + // tag. + Tensor alpha; + alpha.Resize(emission_dims); + T* alpha_value = alpha.mutable_data(); + Tensor track; + track.Resize(emission_dims); + int* track_value = track.mutable_data(); + + const int state_trans_base_idx = 2; + for (int i = 0; i < tag_num; ++i) { + alpha_value[i] = w[i] + x[i]; + } + + for (int k = 1; k < seq_len; ++k) { + for (int i = 0; i < tag_num; ++i) { + T max_score = -std::numeric_limits::max(); + int max_j = 0; + for (size_t j = 0; j < tag_num; ++j) { + T score = alpha_value[(k - 1) * tag_num + j] + + w[(j + state_trans_base_idx) * tag_num + i]; + if (score > max_score) { + max_score = score; + max_j = j; + } + } + alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i]; + track_value[k * tag_num + i] = max_j; + } + } + + T max_score = -std::numeric_limits::max(); + int max_i = 0; + for (size_t i = 0; i < tag_num; ++i) { + T score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i]; + if (score > max_score) { + max_score = score; + max_i = i; + } + } + path[seq_len - 1] = max_i; + for (int k = seq_len - 1; k >= 1; --k) { + path[k - 1] = max_i = track_value[k * tag_num + max_i]; + } +} + +class CrfDecodingCompute : public KernelLite { + public: + void Run() override; + + virtual ~CrfDecodingCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 36d69e68b5..046b47e675 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -104,6 +104,7 @@ add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DE add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS ${op_DEPS}) add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS}) add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS}) +add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/crf_decoding_op.cc b/lite/operators/crf_decoding_op.cc new file mode 100644 index 0000000000..1b0a27ab4a --- /dev/null +++ b/lite/operators/crf_decoding_op.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/crf_decoding_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool CrfDecodingOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.emission); + CHECK_OR_FALSE(param_.transition); + CHECK_OR_FALSE(param_.viterbi_path); + + auto emission_dims = param_.emission->dims(); + if (param_.length == nullptr) { + CHECK_OR_FALSE(emission_dims.size() == 2); + } else { + CHECK_OR_FALSE(emission_dims.size() == 3); + } + CHECK_OR_FALSE(emission_dims[0] != 0); + + auto transition_dims = param_.transition->dims(); + CHECK_OR_FALSE(transition_dims.size() == 2); + CHECK_OR_FALSE(transition_dims[0] - 2 == transition_dims[1]); + + if ((emission_dims[emission_dims.size() - 1] > 0 && + transition_dims[transition_dims.size() - 1] > 0)) { + CHECK_OR_FALSE(emission_dims[emission_dims.size() - 1] == + transition_dims[transition_dims.size() - 1]); + } + + if (param_.label != nullptr) { + auto label_dims = param_.label->dims(); + if (param_.length != nullptr) { + CHECK_OR_FALSE((label_dims.size() == 3UL && label_dims[2] == 1) || + label_dims.size() == 2UL); + } else { + CHECK_OR_FALSE((label_dims.size() == 2UL && label_dims[1] == 1) || + label_dims.size() == 1UL); + } + if (emission_dims[0] > 0 && label_dims[0] > 0) { + CHECK_OR_FALSE(emission_dims[0] == label_dims[0]); + } + } + return true; +} + +bool CrfDecodingOpLite::InferShape() const { + auto emission_dims = param_.emission->dims(); + if (param_.length == nullptr) { + param_.viterbi_path->Resize({emission_dims[0], 1}); + } else { + param_.viterbi_path->Resize({emission_dims[0], emission_dims[1]}); + } + param_.viterbi_path->set_lod(param_.emission->lod()); + return true; +} + +bool CrfDecodingOpLite::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + // inputs + param_.emission = scope->FindVar(op_desc.Input("Emission").front()) + ->GetMutable(); + param_.transition = scope->FindVar(op_desc.Input("Transition").front()) + ->GetMutable(); + if (op_desc.HasInput("Label") && op_desc.Input("Label").size() > 0) { + param_.label = scope->FindVar(op_desc.Input("Label").front()) + ->GetMutable(); + } + if (op_desc.HasInput("Length") && op_desc.Input("Length").size() > 0) { + param_.length = scope->FindVar(op_desc.Input("Length").front()) + ->GetMutable(); + } + + // outputs + param_.viterbi_path = scope->FindVar(op_desc.Output("ViterbiPath").front()) + ->GetMutable(); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(crf_decoding, paddle::lite::operators::CrfDecodingOpLite); diff --git a/lite/operators/crf_decoding_op.h b/lite/operators/crf_decoding_op.h new file mode 100644 index 0000000000..6aaf338ec2 --- /dev/null +++ b/lite/operators/crf_decoding_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class CrfDecodingOpLite : public OpLite { + public: + CrfDecodingOpLite() {} + + explicit CrfDecodingOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "crf_decoding"; } + + private: + mutable CrfDecodingParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 0ca7e3d2a8..6d18f1bf34 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1159,6 +1159,14 @@ struct LstmParam { std::string candidate_activation; }; +struct CrfDecodingParam { + lite::Tensor* emission{}; + lite::Tensor* transition{}; + lite::Tensor* label{}; + lite::Tensor* length{}; + lite::Tensor* viterbi_path{}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index aab078ccdc..de10403761 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -75,4 +75,5 @@ endif() lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) endif() diff --git a/lite/tests/kernels/crf_decoding_compute_test.cc b/lite/tests/kernels/crf_decoding_compute_test.cc new file mode 100644 index 0000000000..7eaed73505 --- /dev/null +++ b/lite/tests/kernels/crf_decoding_compute_test.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class CrfDecodingComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string emission_ = "Emission"; + std::string transition_ = "Transition"; + std::string output_ = "ViterbiPath"; + + public: + CrfDecodingComputeTester(const Place& place, const std::string& alias) + : TestCase(place, alias) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(output_); + CHECK(out); + out->Resize({5, 1}); + LoD out_lod; + out_lod.push_back({0, 2, 5}); + out->set_lod(out_lod); + + std::vector data = {0, 1, 0, 2, 2}; + auto* out_data = out->mutable_data(); + for (int i = 0; i < data.size(); i++) { + out_data[i] = data[i]; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("crf_decoding"); + op_desc->SetInput("Emission", {emission_}); + op_desc->SetInput("Transition", {transition_}); + op_desc->SetOutput("ViterbiPath", {output_}); + } + + void PrepareData() override { + std::vector emission_data = {0.39293837, + -0.42772133, + -0.54629709, + 0.10262954, + 0.43893794, + -0.15378708, + 0.9615284, + 0.36965948, + -0.0381362, + -0.21576496, + -0.31364397, + 0.45809941}; + LoD lod; + lod.push_back({0, 2, 5}); + SetCommonTensor(emission_, DDim({5, 3}), emission_data.data(), lod); + + std::vector transition_data = {0.2379954057320357, + -0.3175082695465, + -0.32454824385250747, + 0.03155137384183837, + 0.03182758709686606, + 0.13440095855132106, + 0.34943179407778957, + 0.22445532486063524, + 0.11102351067758287, + 0.22244338257022156, + -0.1770410861468218, + -0.1382113443776859, + -0.2717367691210444, + -0.20628595361117064, + 0.13097612385448776}; + SetCommonTensor(transition_, DDim({5, 3}), transition_data.data()); + } +}; + +TEST(CrfDecoding, arm_precision) { + LOG(INFO) << "test crf_decoding op"; +#ifdef LITE_WITH_X86 + Place place(TARGET(kHost)); + std::unique_ptr tester( + new CrfDecodingComputeTester(place, "def")); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +#endif + +#ifdef LITE_WITH_ARM + Place place(TARGET(kHost)); + std::unique_ptr tester( + new CrfDecodingComputeTester(place, "def")); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +#endif +} + +} // namespace lite +} // namespace paddle -- GitLab