提交 8fe3a3aa 编写于 作者: P Peng Li

Add excluded_chunk_types to ChunkEvaluator

Chunks whose type is listed in excluded_chunk_types are not counted by
ChunkEvaluator. This is useful for tasks such as semantic role labeling (SRL),
in which chunks of type V (verb) are not taken into account in evaluation.
上级 8a42a549
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <set>
#include <vector> #include <vector>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
...@@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator { ...@@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
std::vector<Segment> labelSegments_; std::vector<Segment> labelSegments_;
std::vector<Segment> outputSegments_; std::vector<Segment> outputSegments_;
std::set<int> excludedChunkTypes_;
public: public:
virtual void init(const EvaluatorConfig& config) { virtual void init(const EvaluatorConfig& config) {
...@@ -105,6 +107,10 @@ public: ...@@ -105,6 +107,10 @@ public:
} }
CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config"; CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config";
otherChunkType_ = numChunkTypes_ = config.num_chunk_types(); otherChunkType_ = numChunkTypes_ = config.num_chunk_types();
// the chunks of types in excludedChunkTypes_ will not be counted
auto& tmp = config.excluded_chunk_types();
excludedChunkTypes_.insert(tmp.begin(), tmp.end());
} }
virtual void start() { virtual void start() {
...@@ -157,7 +163,8 @@ public: ...@@ -157,7 +163,8 @@ public:
size_t i = 0, j = 0; size_t i = 0, j = 0;
while (i < outputSegments_.size() && j < labelSegments_.size()) { while (i < outputSegments_.size() && j < labelSegments_.size()) {
if (outputSegments_[i] == labelSegments_[j]) { if (outputSegments_[i] == labelSegments_[j]) {
++numCorrect_; if (excludedChunkTypes_.count(outputSegments_[i].type) != 1)
++numCorrect_;
} }
if (outputSegments_[i].end < labelSegments_[j].end) { if (outputSegments_[i].end < labelSegments_[j].end) {
++i; ++i;
...@@ -168,8 +175,12 @@ public: ...@@ -168,8 +175,12 @@ public:
++j; ++j;
} }
} }
numLabelSegments_ += labelSegments_.size(); for (auto& segment : labelSegments_) {
numOutputSegments_ += outputSegments_.size(); if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_;
}
for (auto& segment : outputSegments_) {
if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_;
}
} }
void getSegments(int* label, int length, std::vector<Segment>& segments) { void getSegments(int* label, int length, std::vector<Segment>& segments) {
......
...@@ -433,8 +433,12 @@ message EvaluatorConfig { ...@@ -433,8 +433,12 @@ message EvaluatorConfig {
repeated string input_layers = 3; repeated string input_layers = 3;
// Used by ChunkEvaluator // Used by ChunkEvaluator
optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES" // one of "IOB", "IOE", "IOBES"
optional int32 num_chunk_types = 5; // number of chunk types other than "other" optional string chunk_scheme = 4;
// number of chunk types other than "other"
optional int32 num_chunk_types = 5;
// chunks of these types are not counted
repeated int32 excluded_chunk_types = 12;
// Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator
// For multi binary labels: true if output > classification_threshold // For multi binary labels: true if output > classification_threshold
...@@ -453,6 +457,8 @@ message EvaluatorConfig { ...@@ -453,6 +457,8 @@ message EvaluatorConfig {
// whether to delimit the sequence in the seq_text_printer // whether to delimit the sequence in the seq_text_printer
optional bool delimited = 11 [default = true]; optional bool delimited = 11 [default = true];
// NOTE: 12 has been occupied by excluded_chunk_types
} }
message LinkConfig { message LinkConfig {
......
...@@ -1240,7 +1240,8 @@ def Evaluator( ...@@ -1240,7 +1240,8 @@ def Evaluator(
dict_file=None, dict_file=None,
result_file=None, result_file=None,
num_results=None, num_results=None,
delimited=None, ): delimited=None,
excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add() evaluator = g_config.model_config.evaluators.add()
evaluator.type = type evaluator.type = type
evaluator.name = MakeLayerNameInSubmodel(name) evaluator.name = MakeLayerNameInSubmodel(name)
...@@ -1269,6 +1270,9 @@ def Evaluator( ...@@ -1269,6 +1270,9 @@ def Evaluator(
if delimited is not None: if delimited is not None:
evaluator.delimited = delimited evaluator.delimited = delimited
if excluded_chunk_types:
evaluator.excluded_chunk_types.extend(excluded_chunk_types)
class LayerBase(object): class LayerBase(object):
def __init__( def __init__(
......
...@@ -57,19 +57,21 @@ def evaluator(*attrs): ...@@ -57,19 +57,21 @@ def evaluator(*attrs):
return impl return impl
def evaluator_base(input, def evaluator_base(
type, input,
label=None, type,
weight=None, label=None,
name=None, weight=None,
chunk_scheme=None, name=None,
num_chunk_types=None, chunk_scheme=None,
classification_threshold=None, num_chunk_types=None,
positive_label=None, classification_threshold=None,
dict_file=None, positive_label=None,
result_file=None, dict_file=None,
num_results=None, result_file=None,
delimited=None): num_results=None,
delimited=None,
excluded_chunk_types=None, ):
""" """
Evaluator will evaluate the network status while training/testing. Evaluator will evaluate the network status while training/testing.
...@@ -127,7 +129,8 @@ def evaluator_base(input, ...@@ -127,7 +129,8 @@ def evaluator_base(input,
positive_label=positive_label, positive_label=positive_label,
dict_file=dict_file, dict_file=dict_file,
result_file=result_file, result_file=result_file,
delimited=delimited) delimited=delimited,
excluded_chunk_types=excluded_chunk_types, )
@evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
...@@ -330,7 +333,8 @@ def chunk_evaluator( ...@@ -330,7 +333,8 @@ def chunk_evaluator(
label, label,
chunk_scheme, chunk_scheme,
num_chunk_types, num_chunk_types,
name=None, ): name=None,
excluded_chunk_types=None, ):
""" """
Chunk evaluator is used to evaluate segment labelling accuracy for a Chunk evaluator is used to evaluate segment labelling accuracy for a
sequence. It calculates the chunk detection F1 score. sequence. It calculates the chunk detection F1 score.
...@@ -376,6 +380,8 @@ def chunk_evaluator( ...@@ -376,6 +380,8 @@ def chunk_evaluator(
:param num_chunk_types: number of chunk types other than "other" :param num_chunk_types: number of chunk types other than "other"
:param name: The Evaluator name, it is optional. :param name: The Evaluator name, it is optional.
:type name: basename|None :type name: basename|None
:param excluded_chunk_types: chunks of these types are not considered
:type excluded_chunk_types: list of integer|None
""" """
evaluator_base( evaluator_base(
name=name, name=name,
...@@ -383,7 +389,8 @@ def chunk_evaluator( ...@@ -383,7 +389,8 @@ def chunk_evaluator(
input=input, input=input,
label=label, label=label,
chunk_scheme=chunk_scheme, chunk_scheme=chunk_scheme,
num_chunk_types=num_chunk_types) num_chunk_types=num_chunk_types,
excluded_chunk_types=excluded_chunk_types, )
@evaluator(EvaluatorAttribute.FOR_UTILS) @evaluator(EvaluatorAttribute.FOR_UTILS)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册