未验证 提交 80fb550c 编写于 作者: C cc 提交者: GitHub

Add crf_decoding op, test=develop (#3167)

上级 3b01a955
......@@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps})
#lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
#lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/crf_decoding_compute.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void CrfDecodingCompute::Run() {
auto& param = Param<operators::CrfDecodingParam>();
auto* emission_weights = param.emission;
auto* transition_weights = param.transition;
auto* label = param.label;
auto* decoded_path = param.viterbi_path;
int64_t* path = decoded_path->mutable_data<int64_t>();
std::fill(path, path + decoded_path->numel(), 0);
if (param.length != nullptr) {
auto* length = param.length;
int64_t seq_num = length->numel();
const int64_t* length_data = length->data<int64_t>();
auto in_dims = emission_weights->dims();
Tensor emission_weights_tmp = *emission_weights;
emission_weights_tmp.Resize({in_dims[0] * in_dims[1], in_dims[2]});
decoded_path->Resize({in_dims[0] * in_dims[1], 1});
for (int64_t i = 0; i < seq_num; ++i) {
if (length_data[i] == 0) continue;
int64_t start_pos = i * in_dims[1];
int64_t end_pos = start_pos + length_data[i];
Tensor decoded_path_one_seq =
decoded_path->Slice<int64_t>(start_pos, end_pos);
Decode<float>(emission_weights_tmp.Slice<float>(start_pos, end_pos),
*transition_weights,
&decoded_path_one_seq);
}
if (label != nullptr) {
const int64_t* label_value = label->data<int64_t>();
for (int64_t i = 0; i < seq_num; ++i) {
for (int64_t j = 0; j < in_dims[1]; ++j) {
int64_t start_pos = i * in_dims[1];
if (j < length_data[i]) {
path[start_pos + j] =
label_value[start_pos + j] == path[start_pos + j] ? 1 : 0;
} else {
path[start_pos + j] = 0;
}
}
}
}
} else {
auto lod = emission_weights->lod();
CHECK_EQ(lod.size(), 1UL);
CHECK_GT(lod.size(), 0);
const size_t level = 0;
const size_t seq_num = lod[level].size() - 1;
for (size_t i = 0; i < seq_num; ++i) {
if (lod[level][i] == lod[level][i + 1]) continue;
int64_t start_pos = static_cast<int64_t>(lod[level][i]);
int64_t end_pos = static_cast<int64_t>(lod[level][i + 1]);
Tensor decoded_path_one_seq =
decoded_path->Slice<int64_t>(start_pos, end_pos);
Decode<float>(emission_weights->Slice<float>(start_pos, end_pos),
*transition_weights,
&decoded_path_one_seq);
}
if (label != nullptr) {
auto label_lod = label->lod();
CHECK_EQ(label_lod.size(), 1);
const int64_t* label_value = label->data<int64_t>();
int64_t numel = label->numel();
for (int64_t i = 0; i < numel; ++i) {
path[i] = label_value[i] == path[i] ? 1 : 0;
}
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(crf_decoding,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::CrfDecodingCompute,
def)
.BindInput("Emission", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Transition", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Label", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Length", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("ViterbiPath",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <limits>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename T>
void Decode(const Tensor& emission_weights,
const Tensor& transition_weights,
Tensor* decoded_path) {
auto emission_dims = emission_weights.dims();
const int64_t seq_len = emission_dims[0];
const int64_t tag_num = emission_dims[1];
const T* x = emission_weights.data<T>();
const T* w = transition_weights.data<T>();
int64_t* path = decoded_path->mutable_data<int64_t>();
// alpha is a memo table. An element alpha(k, v) records the score of the
// best sequence of tags from position 1 to position k with v being the end
// tag.
Tensor alpha;
alpha.Resize(emission_dims);
T* alpha_value = alpha.mutable_data<T>();
Tensor track;
track.Resize(emission_dims);
int* track_value = track.mutable_data<int>();
const int state_trans_base_idx = 2;
for (int i = 0; i < tag_num; ++i) {
alpha_value[i] = w[i] + x[i];
}
for (int k = 1; k < seq_len; ++k) {
for (int i = 0; i < tag_num; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (size_t j = 0; j < tag_num; ++j) {
T score = alpha_value[(k - 1) * tag_num + j] +
w[(j + state_trans_base_idx) * tag_num + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
track_value[k * tag_num + i] = max_j;
}
}
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) {
T score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i];
if (score > max_score) {
max_score = score;
max_i = i;
}
}
path[seq_len - 1] = max_i;
for (int k = seq_len - 1; k >= 1; --k) {
path[k - 1] = max_i = track_value[k * tag_num + max_i];
}
}
class CrfDecodingCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~CrfDecodingCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -104,6 +104,7 @@ add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DE
add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS ${op_DEPS})
add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/crf_decoding_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool CrfDecodingOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.emission);
CHECK_OR_FALSE(param_.transition);
CHECK_OR_FALSE(param_.viterbi_path);
auto emission_dims = param_.emission->dims();
if (param_.length == nullptr) {
CHECK_OR_FALSE(emission_dims.size() == 2);
} else {
CHECK_OR_FALSE(emission_dims.size() == 3);
}
CHECK_OR_FALSE(emission_dims[0] != 0);
auto transition_dims = param_.transition->dims();
CHECK_OR_FALSE(transition_dims.size() == 2);
CHECK_OR_FALSE(transition_dims[0] - 2 == transition_dims[1]);
if ((emission_dims[emission_dims.size() - 1] > 0 &&
transition_dims[transition_dims.size() - 1] > 0)) {
CHECK_OR_FALSE(emission_dims[emission_dims.size() - 1] ==
transition_dims[transition_dims.size() - 1]);
}
if (param_.label != nullptr) {
auto label_dims = param_.label->dims();
if (param_.length != nullptr) {
CHECK_OR_FALSE((label_dims.size() == 3UL && label_dims[2] == 1) ||
label_dims.size() == 2UL);
} else {
CHECK_OR_FALSE((label_dims.size() == 2UL && label_dims[1] == 1) ||
label_dims.size() == 1UL);
}
if (emission_dims[0] > 0 && label_dims[0] > 0) {
CHECK_OR_FALSE(emission_dims[0] == label_dims[0]);
}
}
return true;
}
bool CrfDecodingOpLite::InferShape() const {
auto emission_dims = param_.emission->dims();
if (param_.length == nullptr) {
param_.viterbi_path->Resize({emission_dims[0], 1});
} else {
param_.viterbi_path->Resize({emission_dims[0], emission_dims[1]});
}
param_.viterbi_path->set_lod(param_.emission->lod());
return true;
}
bool CrfDecodingOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
// inputs
param_.emission = scope->FindVar(op_desc.Input("Emission").front())
->GetMutable<lite::Tensor>();
param_.transition = scope->FindVar(op_desc.Input("Transition").front())
->GetMutable<lite::Tensor>();
if (op_desc.HasInput("Label") && op_desc.Input("Label").size() > 0) {
param_.label = scope->FindVar(op_desc.Input("Label").front())
->GetMutable<lite::Tensor>();
}
if (op_desc.HasInput("Length") && op_desc.Input("Length").size() > 0) {
param_.length = scope->FindVar(op_desc.Input("Length").front())
->GetMutable<lite::Tensor>();
}
// outputs
param_.viterbi_path = scope->FindVar(op_desc.Output("ViterbiPath").front())
->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(crf_decoding, paddle::lite::operators::CrfDecodingOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class CrfDecodingOpLite : public OpLite {
public:
CrfDecodingOpLite() {}
explicit CrfDecodingOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "crf_decoding"; }
private:
mutable CrfDecodingParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -1159,6 +1159,14 @@ struct LstmParam {
std::string candidate_activation;
};
struct CrfDecodingParam {
lite::Tensor* emission{};
lite::Tensor* transition{};
lite::Tensor* label{};
lite::Tensor* length{};
lite::Tensor* viterbi_path{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -75,4 +75,5 @@ endif()
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class CrfDecodingComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string emission_ = "Emission";
std::string transition_ = "Transition";
std::string output_ = "ViterbiPath";
public:
CrfDecodingComputeTester(const Place& place, const std::string& alias)
: TestCase(place, alias) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize({5, 1});
LoD out_lod;
out_lod.push_back({0, 2, 5});
out->set_lod(out_lod);
std::vector<int64_t> data = {0, 1, 0, 2, 2};
auto* out_data = out->mutable_data<int64_t>();
for (int i = 0; i < data.size(); i++) {
out_data[i] = data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("crf_decoding");
op_desc->SetInput("Emission", {emission_});
op_desc->SetInput("Transition", {transition_});
op_desc->SetOutput("ViterbiPath", {output_});
}
void PrepareData() override {
std::vector<float> emission_data = {0.39293837,
-0.42772133,
-0.54629709,
0.10262954,
0.43893794,
-0.15378708,
0.9615284,
0.36965948,
-0.0381362,
-0.21576496,
-0.31364397,
0.45809941};
LoD lod;
lod.push_back({0, 2, 5});
SetCommonTensor(emission_, DDim({5, 3}), emission_data.data(), lod);
std::vector<float> transition_data = {0.2379954057320357,
-0.3175082695465,
-0.32454824385250747,
0.03155137384183837,
0.03182758709686606,
0.13440095855132106,
0.34943179407778957,
0.22445532486063524,
0.11102351067758287,
0.22244338257022156,
-0.1770410861468218,
-0.1382113443776859,
-0.2717367691210444,
-0.20628595361117064,
0.13097612385448776};
SetCommonTensor(transition_, DDim({5, 3}), transition_data.data());
}
};
TEST(CrfDecoding, arm_precision) {
LOG(INFO) << "test crf_decoding op";
#ifdef LITE_WITH_X86
Place place(TARGET(kHost));
std::unique_ptr<arena::TestCase> tester(
new CrfDecodingComputeTester(place, "def"));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kHost));
std::unique_ptr<arena::TestCase> tester(
new CrfDecodingComputeTester(place, "def"));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册