Unverified commit 161128ba authored by Qiao Longfei, committed by GitHub

add chunk eval layer (#6296)

* add crf_decoding layer

* fix some typo

* init trunk_evaluator

* add trunk_evaluator layer

* update chunk_eval_op and test, change int32 to int64

* fix a numeric problem

* change layers.trunk_evaluator to layers.trunk_eval

* fix typo

* add precision_val
Parent 1a8f20c6
@@ -58,9 +58,10 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference",
"(Tensor, default: Tensor<int>). Predictions from the network.");
"(Tensor, default: Tensor<int64_t>). "
"Predictions from the network.");
AddInput("Label",
"(Tensor, default: Tensor<int>). The true tag sequences.");
"(Tensor, default: Tensor<int64_t>). The true tag sequences.");
AddOutput("Precision",
"(float). The evaluated precision (called positive predictive "
"value) of chunks on the given mini-batch.");
@@ -84,7 +85,7 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(std::vector<int>{});
AddComment(R"DOC(
For some basics of chunking, please refer to
‘Chunking with Support Vector Mechines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>’.
‘Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>’.
ChunkEvalOp computes the precision, recall, and F1-score of chunk detection,
@@ -97,7 +98,7 @@ Here is a NER example of labeling for these tagging schemes:
IOE: I-PER E-PER O O I-ORG I-ORG I-ORG E-ORG O E-LOC
IOBES: B-PER E-PER O O I-ORG I-ORG I-ORG E-ORG O S-LOC
There are three chunk types(named entity types) including PER(person), ORG(orgnazation)
There are three chunk types(named entity types) including PER(person), ORG(organization)
and LOC(LOCATION), and we can see that the labels have the form <tag type>-<chunk type>.
Since the calculations actually use label ids rather than labels, extra attention
......
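The doc comment above breaks off just as it reaches the label-id convention. Purely as an illustration (the exact layout is an assumption inferred from the operator test further down, which fills non-chunk positions with num_chunk_types * num_tag_types), here is one IOB label-to-id mapping under which tag type and chunk type stay recoverable by modulo and integer division:

# Hypothetical mapping, not part of the commit. Assumes the layout
#   tag_type   = label_id %  num_tag_types
#   chunk_type = label_id // num_tag_types
# with the "other" (O) label id equal to num_chunk_types * num_tag_types.
num_tag_types = 2                      # IOB has two in-chunk tags: B and I
chunk_types = ["PER", "ORG", "LOC"]    # chunk types from the NER example above
label_to_id = {}
for chunk_idx, chunk in enumerate(chunk_types):
    for tag_idx, tag in enumerate(["B", "I"]):
        label_to_id["%s-%s" % (tag, chunk)] = chunk_idx * num_tag_types + tag_idx
label_to_id["O"] = len(chunk_types) * num_tag_types
# -> {'B-PER': 0, 'I-PER': 1, 'B-ORG': 2, 'I-ORG': 3, 'B-LOC': 4, 'I-LOC': 5, 'O': 6}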
@@ -35,10 +35,10 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
}
};
void GetSegments(const int* label, int length, std::vector<Segment>& segments,
int num_chunk_types, int num_tag_types, int other_chunk_type,
int tag_begin, int tag_inside, int tag_end,
int tag_single) const {
void GetSegments(const int64_t* label, int length,
std::vector<Segment>& segments, int num_chunk_types,
int num_tag_types, int other_chunk_type, int tag_begin,
int tag_inside, int tag_end, int tag_single) const {
segments.clear();
segments.reserve(length);
int chunk_start = 0;
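The body of GetSegments is omitted above; as a rough Python illustration of what it produces for the IOB scheme (simplified, not the kernel's exact rules, and reusing the assumed label-id layout from the note above):

def iob_segments(label_ids, num_tag_types, other_id):
    # Illustrative IOB chunker: returns (begin_index, end_index, chunk_type)
    # for each chunk; tag 0 is taken to be B, any other tag to be I.
    segments = []
    start, prev_type = None, None
    for i, lab in enumerate(label_ids):
        if lab == other_id:                  # an O label closes any open chunk
            if start is not None:
                segments.append((start, i - 1, prev_type))
            start, prev_type = None, None
            continue
        tag, chunk_type = lab % num_tag_types, lab // num_tag_types
        if tag == 0 or start is None or chunk_type != prev_type:
            if start is not None:            # B (or a type switch) starts a new chunk
                segments.append((start, i - 1, prev_type))
            start = i
        prev_type = chunk_type
    if start is not None:                    # close a chunk that runs to the end
        segments.append((start, len(label_ids) - 1, prev_type))
    return segments

# An IOB labeling of the NER example above, using the hypothetical mapping:
# B-PER I-PER O O B-ORG I-ORG I-ORG I-ORG O B-LOC
print(iob_segments([0, 1, 6, 6, 2, 3, 3, 3, 6, 4], num_tag_types=2, other_id=6))
# -> [(0, 1, 0), (4, 7, 1), (9, 9, 2)], i.e. one PER, one ORG and one LOC chunk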
@@ -152,8 +152,8 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
auto* recall = context.Output<Tensor>("Recall");
auto* f1 = context.Output<Tensor>("F1-Score");
const int* inference_data = inference->data<int>();
const int* label_data = label->data<int>();
const int64_t* inference_data = inference->data<int64_t>();
const int64_t* label_data = label->data<int64_t>();
T* precision_data = precision->mutable_data<T>(context.GetPlace());
T* racall_data = recall->mutable_data<T>(context.GetPlace());
T* f1_data = f1->mutable_data<T>(context.GetPlace());
@@ -179,7 +179,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
((*precision_data) + (*racall_data));
}
void EvalOneSeq(const int* output, const int* label, int length,
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
std::vector<Segment>& output_segments,
std::vector<Segment>& label_segments,
int64_t& num_output_segments, int64_t& num_label_segments,
......
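The kernel's metric lines are truncated above; for orientation, they compute the standard chunk-level precision, recall and F1 over exactly matched segments. A minimal sketch (illustrative Python, not the kernel's code), fed with the 4 correct / 5 inferred / 9 labeled chunks that the operator test below uses as its fixture:

def chunk_metrics(num_correct, num_inferred, num_labeled):
    # precision: share of predicted chunks that exactly match a gold chunk
    precision = float(num_correct) / num_inferred if num_inferred else 0.0
    # recall: share of gold chunks that were exactly recovered
    recall = float(num_correct) / num_labeled if num_labeled else 0.0
    # F1: harmonic mean of precision and recall
    f1 = 2 * precision * recall / (precision + recall) if num_correct else 0.0
    return precision, recall, f1

print(chunk_metrics(4, 5, 9))    # (0.8, 0.444..., 0.571...)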
@@ -632,6 +632,40 @@ def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
return acc_out
def chunk_eval(input,
label,
chunk_scheme,
num_chunk_types,
excluded_chunk_types=None,
**kwargs):
"""
This function computes the precision, recall and F1-score of chunk detection
from the given prediction and label, for the specified chunk scheme (e.g. IOB,
IOE or IOBES) and number of chunk types, and returns them as float32 variables.
"""
helper = LayerHelper("chunk_eval", **kwargs)
# prepare output
precision = helper.create_tmp_variable(dtype="float32")
recall = helper.create_tmp_variable(dtype="float32")
f1_score = helper.create_tmp_variable(dtype="float32")
helper.append_op(
type="chunk_eval",
inputs={"Inference": [input],
"Label": [label]},
outputs={
"Precision": [precision],
"Recall": [recall],
"F1-Score": [f1_score]
},
attrs={
"num_chunk_types": num_chunk_types,
'chunk_scheme': chunk_scheme,
'excluded_chunk_types': excluded_chunk_types or []
})
return precision, recall, f1_score
def sequence_conv(input,
num_filters,
filter_size=3,
......
import math
import numpy as np
import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05
@@ -146,9 +148,13 @@ def main():
# TODO(qiao)
# add dependency track and move this config before optimizer
crf_decode = fluid.layers.crf_decoding(
input=feature_out,
param_attr=fluid.ParamAttr(name='crfw'))
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
precision, recall, f1_score = fluid.layers.chunk_eval(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
train_data = paddle.batch(
paddle.reader.shuffle(
@@ -173,10 +179,17 @@ def main():
for data in train_data():
outs = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
fetch_list=[avg_cost, precision, recall, f1_score])
avg_cost_val = np.array(outs[0])
precision_val = np.array(outs[1])
recall_val = np.array(outs[2])
f1_score_val = np.array(outs[3])
if batch_id % 10 == 0:
print("avg_cost=" + str(avg_cost_val))
print("precision_val=" + str(precision_val))
print("recall_val:" + str(recall_val))
print("f1_score_val:" + str(f1_score_val))
# exit early for CI
exit(0)
......
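A note on the num_chunk_types argument used above: under the IOB scheme each chunk type contributes two labels (B-x and I-x) and there is a single O label, so label_dict_len = 2 * num_chunk_types + 1, which the ceil((label_dict_len - 1) / 2.0) expression simply inverts. A tiny worked example (the dictionary size here is hypothetical, not the actual CoNLL-05 value):

import math

label_dict_len = 2 * 3 + 1    # hypothetical IOB dictionary: B/I for 3 chunk types, plus O
num_chunk_types = int(math.ceil((label_dict_len - 1) / 2.0))
assert num_chunk_types == 3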
@@ -120,7 +120,7 @@ class TestChunkEvalOp(OpTest):
self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks = 4, 5, 9
def set_data(self):
infer = np.zeros((self.batch_size, )).astype('int32')
infer = np.zeros((self.batch_size, )).astype('int64')
infer.fill(self.num_chunk_types * self.num_tag_types)
label = np.copy(infer)
starts = np.random.choice(
......
@@ -130,6 +130,7 @@ class TestBook(unittest.TestCase):
def test_linear_chain_crf(self):
program = Program()
with program_guard(program, startup_program=Program()):
label_dict_len = 10
images = layers.data(name='pixel', shape=[784], dtype='float32')
label = layers.data(name='label', shape=[1], dtype='int32')
hidden = layers.fc(input=images, size=128)
@@ -137,6 +138,11 @@ class TestBook(unittest.TestCase):
input=hidden, label=label, param_attr=ParamAttr(name="crfw"))
crf_decode = layers.crf_decoding(
input=hidden, param_attr=ParamAttr(name="crfw"))
layers.chunk_eval(
input=crf_decode,
label=label,
chunk_scheme="IOB",
num_chunk_types=(label_dict_len - 1) / 2)
self.assertNotEqual(crf, None)
self.assertNotEqual(crf_decode, None)
......