From 8fe3a3aa73be9b7d1f748e3809dff9f5323be719 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Tue, 20 Dec 2016 16:46:42 +0800 Subject: [PATCH] Add excluded_chunk_types to ChunkEvaluator The chunks of types in excluded_chunk_types will not be counted in ChunkEvaluator. This is useful for tasks such as SRL, in which chunks of type V (verb) will not be taken into account in evaluation. --- paddle/gserver/evaluators/ChunkEvaluator.cpp | 17 ++++++-- proto/ModelConfig.proto | 10 ++++- python/paddle/trainer/config_parser.py | 6 ++- .../trainer_config_helpers/evaluators.py | 39 +++++++++++-------- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 3d8af5bcd4..15e0e95206 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/math/Vector.h" @@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator { std::vector labelSegments_; std::vector outputSegments_; + std::set excludedChunkTypes_; public: virtual void init(const EvaluatorConfig& config) { @@ -105,6 +107,10 @@ public: } CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config"; otherChunkType_ = numChunkTypes_ = config.num_chunk_types(); + + // the chunks of types in excludedChunkTypes_ will not be counted + auto& tmp = config.excluded_chunk_types(); + excludedChunkTypes_.insert(tmp.begin(), tmp.end()); } virtual void start() { @@ -157,7 +163,8 @@ public: size_t i = 0, j = 0; while (i < outputSegments_.size() && j < labelSegments_.size()) { if (outputSegments_[i] == labelSegments_[j]) { - ++numCorrect_; + if (excludedChunkTypes_.count(outputSegments_[i].type) != 1) + ++numCorrect_; } if (outputSegments_[i].end < labelSegments_[j].end) { ++i; @@ -168,8 +175,12 @@ public: ++j; } } - numLabelSegments_ += labelSegments_.size(); - numOutputSegments_ += outputSegments_.size(); + for (auto& segment : labelSegments_) { + if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_; + } + for (auto& segment : outputSegments_) { + if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_; + } } void getSegments(int* label, int length, std::vector& segments) { diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 552af71e76..e24ed21fbb 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -433,8 +433,12 @@ message EvaluatorConfig { repeated string input_layers = 3; // Used by ChunkEvaluator - optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES" - optional int32 num_chunk_types = 5; // number of chunk types other than "other" + // one of "IOB", "IOE", "IOBES" + optional string chunk_scheme = 4; + // number of chunk types other than "other" + optional int32 num_chunk_types = 5; + // chunk of these types are not counted + repeated int32 excluded_chunk_types = 12; // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator // For multi binary labels: true if output > classification_threshold @@ -453,6 +457,8 @@ message EvaluatorConfig { // whether to delimit the sequence in the seq_text_printer optional bool delimited = 11 [default = true]; + + // NOTE: 12 has been occupied by excluded_chunk_types } message LinkConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index ea3e4308fe..39892d0533 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1240,7 +1240,8 @@ def Evaluator( dict_file=None, result_file=None, num_results=None, - delimited=None, ): + delimited=None, + excluded_chunk_types=None, ): evaluator = g_config.model_config.evaluators.add() evaluator.type = type evaluator.name = MakeLayerNameInSubmodel(name) @@ -1269,6 +1270,9 @@ def Evaluator( if delimited is not None: evaluator.delimited = delimited + if excluded_chunk_types: + evaluator.excluded_chunk_types.extend(excluded_chunk_types) + class LayerBase(object): def __init__( diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index 3e0e88972c..731e30d367 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -57,19 +57,21 @@ def evaluator(*attrs): return impl -def evaluator_base(input, - type, - label=None, - weight=None, - name=None, - chunk_scheme=None, - num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - delimited=None): +def evaluator_base( + input, + type, + label=None, + weight=None, + name=None, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + delimited=None, + excluded_chunk_types=None, ): """ Evaluator will evaluate the network status while training/testing. @@ -127,7 +129,8 @@ def evaluator_base(input, positive_label=positive_label, dict_file=dict_file, result_file=result_file, - delimited=delimited) + delimited=delimited, + excluded_chunk_types=excluded_chunk_types, ) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) @@ -330,7 +333,8 @@ def chunk_evaluator( label, chunk_scheme, num_chunk_types, - name=None, ): + name=None, + excluded_chunk_types=None, ): """ Chunk evaluator is used to evaluate segment labelling accuracy for a sequence. It calculates the chunk detection F1 score. @@ -376,6 +380,8 @@ def chunk_evaluator( :param num_chunk_types: number of chunk types other than "other" :param name: The Evaluator name, it is optional. :type name: basename|None + :param excluded_chunk_types: chunks of these types are not considered + :type excluded_chunk_types: list of integer|[] """ evaluator_base( name=name, @@ -383,7 +389,8 @@ def chunk_evaluator( input=input, label=label, chunk_scheme=chunk_scheme, - num_chunk_types=num_chunk_types) + num_chunk_types=num_chunk_types, + excluded_chunk_types=excluded_chunk_types, ) @evaluator(EvaluatorAttribute.FOR_UTILS) -- GitLab