提交 26318544 编写于 作者: Z zhoukunsheng 提交者: XiaoguangHu

support Tensor input for chunk_eval op (#18226)

* test=develop
support Tensor input for chunk_eval op

* test=develop
fix testcase for chunk_eval op

* test=develop
fix typos in nn.py
上级 206c44e2
...@@ -105,7 +105,7 @@ paddle.fluid.layers.cos_sim (ArgSpec(args=['X', 'Y'], varargs=None, keywords=Non ...@@ -105,7 +105,7 @@ paddle.fluid.layers.cos_sim (ArgSpec(args=['X', 'Y'], varargs=None, keywords=Non
paddle.fluid.layers.cross_entropy (ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)), ('document', 'f43c659ca1749a3f0ff2231e6dfda07d')) paddle.fluid.layers.cross_entropy (ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)), ('document', 'f43c659ca1749a3f0ff2231e6dfda07d'))
paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6263dfdeb6c670fa0922c9cbc8fb1bf4')) paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6263dfdeb6c670fa0922c9cbc8fb1bf4'))
paddle.fluid.layers.square_error_cost (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'bbb9e708bab250359864fefbdf48e9d9')) paddle.fluid.layers.square_error_cost (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'bbb9e708bab250359864fefbdf48e9d9'))
paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,)), ('document', '5aa25d023acea1fb49a0de56be86990b')) paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types', 'seq_length'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'b02844e0ad4bd713c5fe6802aa13219c'))
paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)), ('document', '3d8e8f3e0e1cf520156be37605e83ccd')) paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)), ('document', '3d8e8f3e0e1cf520156be37605e83ccd'))
paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '114c7fe6b0adfc6d6371122f9b9f506e')) paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '114c7fe6b0adfc6d6371122f9b9f506e'))
paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '367293b5bada54136a91621078d38334')) paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '367293b5bada54136a91621078d38334'))
......
...@@ -48,6 +48,15 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -48,6 +48,15 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(inference_dim == label_dim, PADDLE_ENFORCE(inference_dim == label_dim,
"Inference's shape must be the same as Label's shape."); "Inference's shape must be the same as Label's shape.");
bool use_padding = ctx->HasInput("SeqLength");
if (use_padding) {
PADDLE_ENFORCE(inference_dim.size() == 3,
"when SeqLength is provided, Inference should be of dim 3 "
"(batch, bucket, 1)");
auto seq_length_dim = ctx->GetInputDim("SeqLength");
PADDLE_ENFORCE(seq_length_dim.size() == 1, "seq_length should be rank 1");
}
ctx->SetOutputDim("Precision", {1}); ctx->SetOutputDim("Precision", {1});
ctx->SetOutputDim("Recall", {1}); ctx->SetOutputDim("Recall", {1});
ctx->SetOutputDim("F1-Score", {1}); ctx->SetOutputDim("F1-Score", {1});
...@@ -72,6 +81,10 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -72,6 +81,10 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
"Predictions from the network."); "Predictions from the network.");
AddInput("Label", AddInput("Label",
"(Tensor, default: Tensor<int64_t>). The true tag sequences."); "(Tensor, default: Tensor<int64_t>). The true tag sequences.");
AddInput("SeqLength",
"(Tensor, default: Tensor<int64_t>). The length of each sequence, "
"used when Inference and Label are Tensor type .")
.AsDispensable();
AddOutput("Precision", AddOutput("Precision",
"(float). The evaluated precision (called positive predictive " "(float). The evaluated precision (called positive predictive "
"value) of chunks on the given mini-batch."); "value) of chunks on the given mini-batch.");
......
...@@ -173,18 +173,41 @@ class ChunkEvalKernel : public framework::OpKernel<T> { ...@@ -173,18 +173,41 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
*num_correct_chunks_data = 0; *num_correct_chunks_data = 0;
auto lod = label->lod(); auto lod = label->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); bool use_padding = lod.empty();
PADDLE_ENFORCE(lod == inference->lod(), int num_sequences = 0;
"LoD must be same between Inference and Label.");
int num_sequences = lod[0].size() - 1; if (use_padding) {
for (int i = 0; i < num_sequences; ++i) { auto dim1 = inference->dims()[1];
int seq_length = lod[0][i + 1] - lod[0][i]; auto* seq_length_t = context.Input<Tensor>("SeqLength");
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length, auto* seq_length_data = seq_length_t->data<int64_t>();
&output_segments, &label_segments, num_infer_chunks_data, num_sequences = seq_length_t->dims()[0];
num_label_chunks_data, num_correct_chunks_data,
num_chunk_types, num_tag_types, other_chunk_type, tag_begin, for (int i = 0; i < num_sequences; ++i) {
tag_inside, tag_end, tag_single, excluded_chunk_types); int seq_length = seq_length_data[i];
EvalOneSeq(inference_data + i * dim1, label_data + i * dim1, seq_length,
&output_segments, &label_segments, num_infer_chunks_data,
num_label_chunks_data, num_correct_chunks_data,
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
tag_inside, tag_end, tag_single, excluded_chunk_types);
}
} else {
PADDLE_ENFORCE_EQ(lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE(lod == inference->lod(),
"LoD must be same between Inference and Label.");
num_sequences = lod[0].size() - 1;
for (int i = 0; i < num_sequences; ++i) {
int seq_length = lod[0][i + 1] - lod[0][i];
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i],
seq_length, &output_segments, &label_segments,
num_infer_chunks_data, num_label_chunks_data,
num_correct_chunks_data, num_chunk_types, num_tag_types,
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single,
excluded_chunk_types);
}
} }
*precision_data = !(*num_infer_chunks_data) *precision_data = !(*num_infer_chunks_data)
? 0 ? 0
: static_cast<T>(*num_correct_chunks_data) / : static_cast<T>(*num_correct_chunks_data) /
......
...@@ -1679,7 +1679,8 @@ def chunk_eval(input, ...@@ -1679,7 +1679,8 @@ def chunk_eval(input,
label, label,
chunk_scheme, chunk_scheme,
num_chunk_types, num_chunk_types,
excluded_chunk_types=None): excluded_chunk_types=None,
seq_length=None):
""" """
**Chunk Evaluator** **Chunk Evaluator**
...@@ -1751,6 +1752,7 @@ def chunk_eval(input, ...@@ -1751,6 +1752,7 @@ def chunk_eval(input,
chunk_scheme (str): ${chunk_scheme_comment} chunk_scheme (str): ${chunk_scheme_comment}
num_chunk_types (int): ${num_chunk_types_comment} num_chunk_types (int): ${num_chunk_types_comment}
excluded_chunk_types (list): ${excluded_chunk_types_comment} excluded_chunk_types (list): ${excluded_chunk_types_comment}
seq_length(Variable): 1-D Tensor specifying sequence length when input and label are Tensor type.
Returns: Returns:
tuple: tuple containing: precision, recall, f1_score, tuple: tuple containing: precision, recall, f1_score,
...@@ -1792,10 +1794,14 @@ def chunk_eval(input, ...@@ -1792,10 +1794,14 @@ def chunk_eval(input,
num_correct_chunks = helper.create_variable_for_type_inference( num_correct_chunks = helper.create_variable_for_type_inference(
dtype="int64") dtype="int64")
this_input = {"Inference": [input], "Label": [label]}
if seq_length:
this_input["SeqLength"] = [seq_length]
helper.append_op( helper.append_op(
type="chunk_eval", type="chunk_eval",
inputs={"Inference": [input], inputs=this_input,
"Label": [label]},
outputs={ outputs={
"Precision": [precision], "Precision": [precision],
"Recall": [recall], "Recall": [recall],
......
...@@ -150,7 +150,7 @@ class TestChunkEvalOp(OpTest): ...@@ -150,7 +150,7 @@ class TestChunkEvalOp(OpTest):
lod = [] lod = []
for i in range(len(starts) - 1): for i in range(len(starts) - 1):
lod.append(starts[i + 1] - starts[i]) lod.append(starts[i + 1] - starts[i])
self.inputs = {'Inference': (infer, [lod]), 'Label': (label, [lod])} self.set_input(infer, label, lod)
precision = float( precision = float(
self.num_correct_chunks self.num_correct_chunks
) / self.num_infer_chunks if self.num_infer_chunks else 0 ) / self.num_infer_chunks if self.num_infer_chunks else 0
...@@ -173,6 +173,9 @@ class TestChunkEvalOp(OpTest): ...@@ -173,6 +173,9 @@ class TestChunkEvalOp(OpTest):
[self.num_correct_chunks], dtype='int64') [self.num_correct_chunks], dtype='int64')
} }
def set_input(self, infer, label, lod):
self.inputs = {'Inference': (infer, [lod]), 'Label': (label, [lod])}
def setUp(self): def setUp(self):
self.op_type = 'chunk_eval' self.op_type = 'chunk_eval'
self.set_confs() self.set_confs()
...@@ -198,5 +201,33 @@ class TestChunkEvalOpWithExclude(TestChunkEvalOp): ...@@ -198,5 +201,33 @@ class TestChunkEvalOpWithExclude(TestChunkEvalOp):
self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks = 15, 18, 20 self.num_correct_chunks, self.num_infer_chunks, self.num_label_chunks = 15, 18, 20
class TestChunkEvalOpWithTensorInput(TestChunkEvalOp):
def set_input(self, infer, label, lod):
max_len = np.max(lod)
pad_infer = []
pad_label = []
start = 0
for i in range(len(lod)):
end = lod[i] + start
pad_infer.append(
np.pad(infer[start:end], (0, max_len - lod[i]),
'constant',
constant_values=(-1, )))
pad_label.append(
np.pad(label[start:end], (0, max_len - lod[i]),
'constant',
constant_values=(-1, )))
start = end
pad_infer = np.expand_dims(np.array(pad_infer, dtype='int64'), 2)
pad_label = np.expand_dims(np.array(pad_label, dtype='int64'), 2)
lod = np.array(lod, dtype='int64')
self.inputs = {
'Inference': pad_infer,
'Label': pad_label,
'SeqLength': lod
}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册