提交 bb9d68dc 编写于 作者: G guosheng

Add chunk_eval_op

上级 6604d7cd
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/chunk_eval_op.h"
namespace paddle {
namespace operators {
class ChunkEvalOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input(Inference) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Precision"),
"Output(Precision) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Recall"),
"Output(Recall) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("F1-Score"),
"Output(F1-Score) of ChunkEvalOp should not be null.");
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");
PADDLE_ENFORCE(inference_dim == label_dim,
"Inference's shape must be the same as Label's shape.");
ctx->SetOutputDim("Precision", {1});
ctx->SetOutputDim("Recall", {1});
ctx->SetOutputDim("F1-Score", {1});
}
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return framework::DataType::FP32;
}
};
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChunkEvalOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference",
"(Tensor, default: Tensor<int>) Predictions from the network.");
AddInput("Label", "(Tensor, default: Tensor<int>) Labels of the data.");
AddOutput(
"Precision",
"(float) The precision ratio of the predictions on current data.");
AddOutput("Recall",
"(float) The recall ratio of the predictions on current data.");
AddOutput("F1-Score",
"(float) The F1-Score of the predictions on current data.");
AddAttr<int>("num_chunk_types", "(int) The number of chunk type.");
AddAttr<std::string>("chunk_scheme",
"(string, default IOB) The label scheme.")
.SetDefault("IOB");
AddAttr<std::vector<int>>(
"excluded_chunk_types",
"(list<int>) A list<int> indicating chunk types not to be counted.")
.SetDefault(std::vector<int>{});
AddComment(R"DOC(
Chunk evaluator is used to evaluate segment labelling accuracy for a
sequence. It calculates precision, recall and F1 scores for the chunk detection.
To use chunk evaluator, several concepts need to be clarified firstly.
[Chunk type] is the type of the whole chunk and a chunk consists of one or several words. (For example in NER, ORG for organization name, PER for person name etc.)
[Tag type] indicates the position of a word in a chunk. (B for begin, I for inside, E for end, S for single)
We can name a label by combining tag type and chunk type. (ie. B-ORG for begining of an organization name)
The construction of label dictionary should obey the following rules:
- Use one of the listed labelling schemes. These schemes differ in ways indicating chunk boundry.
Scheme Description
plain Use the same label for the whole chunk.
IOB Two labels for chunk type X, B-X for chunk begining and I-X for chunk inside.
IOE Two labels for chunk type X, E-X for chunk ending and I-X for chunk inside.
IOBES Four labels for chunk type X, B-X for chunk begining, I-X for chunk inside, E-X for chunk end and S-X for single word chunk.
To make it clear, let's illustrate by an NER example.
Assuming that there are three named entity types including ORG, PER and LOC which are called 'chunk type' here,
if 'IOB' scheme were used, the label set will be extended to a set including B-ORG, I-ORG, B-PER, I-PER, B-LOC, I-LOC and O,
in which B-ORG for begining of ORG and I-ORG for inside of ORG.
Prefixes which are called 'tag type' here are added to chunk types and there are two tag types including B and I.
Of course, the training data should be labeled accordingly.
- Mapping is done correctly by the listed equations and assigning protocol.
The following table are equations to extract tag type and chunk type from a label.
tagType = label % numTagType
chunkType = label / numTagType
otherChunkType = numChunkTypes
The following table shows the mapping rule between tagType and tag type in each scheme.
Scheme Begin Inside End Single
plain 0 - - -
IOB 0 1 - -
IOE - 0 1 -
IOBES 0 1 2 3
Continue the NER example, and the label dict should look like this to satify above equations:
B-ORG 0
I-ORG 1
B-PER 2
I-PER 3
B-LOC 4
I-LOC 5
O 6
In this example, chunkType has three values: 0 for ORG, 1 for PER, 2 for LOC, because the scheme is
"IOB" so tagType has two values: 0 for B and 1 for I.
Here we will use I-LOC to explain the above mapping rules in detail.
For I-LOC, the label id is 5, so we can get tagType=1 and chunkType=2, which means I-LOC is a part of NER chunk LOC
and the tag is I.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(chunk_eval, ops::ChunkEvalOp,
ops::ChunkEvalOpMaker);
REGISTER_OP_CPU_KERNEL(chunk_eval,
ops::ChunkEvalKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <set>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class ChunkEvalKernel : public framework::OpKernel<T> {
public:
struct Segment {
int begin;
int end;
int type;
bool operator==(const Segment& y) const {
return begin == y.begin && end == y.end && type == y.type;
}
};
void GetSegments(const int* label, int length, std::vector<Segment>& segments,
int num_chunk_types, int num_tag_types, int other_chunk_type,
int tag_begin, int tag_inside, int tag_end,
int tag_single) const {
segments.clear();
segments.reserve(length);
int chunk_start = 0;
bool in_chunk = false;
int tag = -1;
int type = other_chunk_type;
for (int i = 0; i < length; ++i) {
int prev_tag = tag;
int prev_type = type;
PADDLE_ENFORCE_LE(label[i], num_chunk_types * num_tag_types);
tag = label[i] % num_tag_types;
type = label[i] / num_tag_types;
if (in_chunk && ChunkEnd(prev_tag, prev_type, tag, type, other_chunk_type,
tag_begin, tag_inside, tag_end, tag_single)) {
Segment segment{
chunk_start, // begin
i - 1, // end
prev_type,
};
segments.push_back(segment);
in_chunk = false;
}
if (ChunkBegin(prev_tag, prev_type, tag, type, other_chunk_type,
tag_begin, tag_inside, tag_end, tag_single)) {
chunk_start = i;
in_chunk = true;
}
}
if (in_chunk) {
Segment segment{
chunk_start, // begin
length - 1, // end
type,
};
segments.push_back(segment);
}
}
bool ChunkEnd(int prev_tag, int prev_type, int tag, int type,
int other_chunk_type, int tag_begin, int tag_inside,
int tag_end, int tag_single) const {
if (prev_type == other_chunk_type) return false;
if (type == other_chunk_type) return true;
if (type != prev_type) return true;
if (prev_tag == tag_begin) return tag == tag_begin || tag == tag_single;
if (prev_tag == tag_inside) return tag == tag_begin || tag == tag_single;
if (prev_tag == tag_end) return true;
if (prev_tag == tag_single) return true;
return false;
}
bool ChunkBegin(int prev_tag, int prev_type, int tag, int type,
int other_chunk_type, int tag_begin, int tag_inside,
int tag_end, int tag_single) const {
if (prev_type == other_chunk_type) return type != other_chunk_type;
if (type == other_chunk_type) return false;
if (type != prev_type) return true;
if (tag == tag_begin) return true;
if (tag == tag_inside) return prev_tag == tag_end || prev_tag == tag_single;
if (tag == tag_end) return prev_tag == tag_end || prev_tag == tag_single;
if (tag == tag_single) return true;
return false;
}
void Compute(const framework::ExecutionContext& context) const override {
// initialize to parse configurations
int num_chunk_types, num_tag_types;
int other_chunk_type;
int tag_begin, tag_inside, tag_end, tag_single;
std::vector<Segment> label_segments;
std::vector<Segment> output_segments;
std::set<int> excluded_chunk_types;
int64_t num_output_segments = 0;
int64_t num_label_segments = 0;
int64_t num_correct = 0;
if (context.Attr<std::string>("chunk_scheme") == "IOB") {
num_tag_types = 2;
tag_begin = 0;
tag_inside = 1;
tag_end = -1;
tag_single = -1;
} else if (context.Attr<std::string>("chunk_scheme") == "IOE") {
num_tag_types = 2;
tag_begin = -1;
tag_inside = 0;
tag_end = 1;
tag_single = -1;
} else if (context.Attr<std::string>("chunk_scheme") == "IOBES") {
num_tag_types = 4;
tag_begin = 0;
tag_inside = 1;
tag_end = 2;
tag_single = 3;
} else if (context.Attr<std::string>("chunk_scheme") == "plain") {
num_tag_types = 1;
tag_begin = -1;
tag_inside = -1;
tag_end = -1;
tag_single = -1;
} else {
PADDLE_THROW("Unknown chunk scheme.");
}
other_chunk_type = num_chunk_types = context.Attr<int>("num_chunk_types");
excluded_chunk_types.insert(
context.Attr<std::vector<int>>("excluded_chunk_types").begin(),
context.Attr<std::vector<int>>("excluded_chunk_types").end());
auto* inference = context.Input<LoDTensor>("Inference");
auto* label = context.Input<LoDTensor>("Label");
auto* precision = context.Output<Tensor>("Precision");
auto* recall = context.Output<Tensor>("Recall");
auto* f1 = context.Output<Tensor>("F1-Score");
const int* inference_data = inference->data<int>();
const int* label_data = label->data<int>();
T* precision_data = precision->mutable_data<T>(context.GetPlace());
T* racall_data = recall->mutable_data<T>(context.GetPlace());
T* f1_data = f1->mutable_data<T>(context.GetPlace());
auto lod = label->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE(lod == inference->lod(),
"LoD must be same between Inference and Label.");
int num_sequences = lod[0].size() - 1;
for (int i = 0; i < num_sequences; ++i) {
int seq_length = lod[0][i + 1] - lod[0][i];
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
output_segments, label_segments, num_output_segments,
num_label_segments, num_correct, num_chunk_types,
num_tag_types, other_chunk_type, tag_begin, tag_inside,
tag_end, tag_single, excluded_chunk_types);
}
*precision_data =
!num_output_segments ? 0 : (T)num_correct / num_output_segments;
*racall_data =
!num_label_segments ? 0 : (T)num_correct / num_label_segments;
*f1_data = !num_correct ? 0 : 2 * (*precision_data) * (*racall_data) /
((*precision_data) + (*racall_data));
}
void EvalOneSeq(const int* output, const int* label, int length,
std::vector<Segment>& output_segments,
std::vector<Segment>& label_segments,
int64_t& num_output_segments, int64_t& num_label_segments,
int64_t& num_correct, int num_chunk_types, int num_tag_types,
int other_chunk_type, int tag_begin, int tag_inside,
int tag_end, int tag_single,
const std::set<int>& excluded_chunk_types) const {
GetSegments(output, length, output_segments, num_chunk_types, num_tag_types,
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
GetSegments(label, length, label_segments, num_chunk_types, num_tag_types,
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
size_t i = 0, j = 0;
while (i < output_segments.size() && j < label_segments.size()) {
if (output_segments[i] == label_segments[j] &&
excluded_chunk_types.count(output_segments[i].type) != 1) {
++num_correct;
}
if (output_segments[i].end < label_segments[j].end) {
++i;
} else if (output_segments[i].end > label_segments[j].end) {
++j;
} else {
++i;
++j;
}
}
for (auto& segment : label_segments) {
if (excluded_chunk_types.count(segment.type) != 1) ++num_label_segments;
}
for (auto& segment : output_segments) {
if (excluded_chunk_types.count(segment.type) != 1) ++num_output_segments;
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class Segments(object):
def __init__(self, chunk_type, start_idx, end_idx):
self.chunk_type = chunk_type
self.start_idx = start_idx
self.end_idx = end_idx
def __str__(self):
return '(Segments: %s, %s, %s)' % (self.chunk_type, self.start_idx,
self.end_idx)
__repr__ = __str__
class TestChunkEvalOp(OpTest):
num_sequences = 5
batch_size = 50
def parse_scheme(self):
if self.scheme == 'IOB':
self.num_tag_types = 2
elif self.scheme == 'IOE':
self.num_tag_types = 2
def fill_with_chunks(self, data, chunks):
for chunk in chunks:
if self.scheme == 'IOB':
data[chunk.start_idx] = chunk.chunk_type * self.num_tag_types
data[chunk.start_idx + 1:
chunk.end_idx] = chunk.chunk_type * self.num_tag_types + (
self.num_tag_types - 1)
data[chunk.end_idx] = chunk.chunk_type * self.num_tag_types + (
self.num_tag_types - 1
) if chunk.start_idx < chunk.end_idx else data[chunk.start_idx]
elif self.scheme == 'IOE':
data[chunk.start_idx:
chunk.end_idx] = chunk.chunk_type * self.num_tag_types
data[chunk.end_idx] = chunk.chunk_type * self.num_tag_types + (
self.num_tag_types - 1)
def rand_chunks(self, starts, num_chunks):
if num_chunks < 0:
num_chunks = np.random.randint(starts[-1])
chunks = []
# generate chunk beginnings
chunk_begins = sorted(
np.random.choice(
range(starts[-1]), num_chunks, replace=False))
seq_chunk_begins = []
begin_idx = 0
# divide chunks into sequences
for i in range(len(starts) - 1):
tmp_chunk_begins = []
while begin_idx < len(chunk_begins) and chunk_begins[
begin_idx] < starts[i + 1]:
tmp_chunk_begins.append(chunk_begins[begin_idx])
begin_idx += 1
seq_chunk_begins.append(tmp_chunk_begins)
# generate chunk ends
chunk_ends = []
for i in range(len(seq_chunk_begins)):
for j in range(len(seq_chunk_begins[i])):
low = seq_chunk_begins[i][j]
high = seq_chunk_begins[i][j + 1] if j < len(seq_chunk_begins[
i]) - 1 else starts[i + 1]
chunk_ends.append(np.random.randint(low, high))
# generate chunks
for chunk_pos in zip(chunk_begins, chunk_ends):
chunk_type = np.random.randint(self.num_chunk_types)
chunks.append(Segments(chunk_type, *chunk_pos))
return chunks
def gen_chunks(self, infer, label, starts):
chunks = self.rand_chunks(starts,
self.num_infer_chunks + self.num_label_chunks
- self.num_correct_chunks)
correct_chunks = np.random.choice(
range(len(chunks)), self.num_correct_chunks, replace=False)
infer_chunks = np.random.choice(
[x for x in range(len(chunks)) if x not in correct_chunks],
self.num_infer_chunks - self.num_correct_chunks,
replace=False)
infer_chunks = sorted(correct_chunks.tolist() + infer_chunks.tolist())
label_chunks = np.random.choice(
[x for x in range(len(chunks)) if x not in infer_chunks],
self.num_label_chunks - self.num_correct_chunks,
replace=False)
label_chunks = sorted(correct_chunks.tolist() + label_chunks.tolist())
self.fill_with_chunks(infer, [chunks[idx] for idx in infer_chunks])
self.fill_with_chunks(label, [chunks[idx] for idx in label_chunks])
# exclude types in excluded_chunk_types
if len(self.excluded_chunk_types) > 0:
for idx in correct_chunks:
if chunks[idx].chunk_type in self.excluded_chunk_types:
self.num_correct_chunks -= 1
for idx in infer_chunks:
if chunks[idx].chunk_type in self.excluded_chunk_types:
self.num_infer_chunks -= 1
for idx in label_chunks:
if chunks[idx].chunk_type in self.excluded_chunk_types:
self.num_label_chunks -= 1
return self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks
def set_confs(self):
# Use the IOB scheme and labels with 2 chunk types
self.scheme = 'IOB'
self.num_chunk_types = 2
self.excluded_chunk_types = []
self.other_chunk_type = self.num_chunk_types
self.attrs = {
'num_chunk_types': self.num_chunk_types,
'chunk_scheme': self.scheme,
'excluded_chunk_types': self.excluded_chunk_types
}
self.parse_scheme()
self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks = 4, 5, 9
def set_data(self):
infer = np.zeros((self.batch_size, )).astype("int32")
infer.fill(self.num_chunk_types * self.num_tag_types)
label = np.copy(infer)
starts = np.random.choice(
range(1, self.batch_size), self.num_sequences - 1,
replace=False).tolist()
starts.extend([0, self.batch_size])
starts = sorted(starts)
self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks = self.gen_chunks(
infer, label, starts)
self.inputs = {
'Inference': (infer, [starts]),
'Label': (label, [starts])
}
precision = float(
self.num_correct_chunks
) / self.num_infer_chunks if self.num_infer_chunks else 0
recall = float(self.num_correct_chunks
) / self.num_label_chunks if self.num_label_chunks else 0
f1 = float(2 * precision * recall) / (
precision + recall) if self.num_correct_chunks else 0
self.outputs = {
'Precision': [precision],
'Recall': [recall],
'F1-Score': [f1]
}
def setUp(self):
self.op_type = 'chunk_eval'
self.set_confs()
self.set_data()
def test_check_output(self):
self.check_output()
class TestChunkEvalOpWithExclude(TestChunkEvalOp):
def set_confs(self):
# Use the IOE scheme and labels with 3 chunk types
self.scheme = 'IOE'
self.num_chunk_types = 3
self.excluded_chunk_types = [1]
self.other_chunk_type = self.num_chunk_types
self.attrs = {
'num_chunk_types': self.num_chunk_types,
'chunk_scheme': self.scheme,
'excluded_chunk_types': self.excluded_chunk_types
}
self.parse_scheme()
self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks = 15, 18, 20
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册