diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 3d8af5bcd419e76fb2026eddc95dc409a33c9d92..13f02e51fe9e3831103982130bfdaa3255e1d174 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/math/Vector.h" @@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator { std::vector labelSegments_; std::vector outputSegments_; + std::set excludedChunkTypes_; public: virtual void init(const EvaluatorConfig& config) { @@ -105,6 +107,10 @@ public: } CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config"; otherChunkType_ = numChunkTypes_ = config.num_chunk_types(); + + // the chunks of types in excludedChunkTypes_ will not be counted + auto& tmp = config.excluded_chunk_types(); + excludedChunkTypes_.insert(tmp.begin(), tmp.end()); } virtual void start() { @@ -156,7 +162,8 @@ public: getSegments(label, length, labelSegments_); size_t i = 0, j = 0; while (i < outputSegments_.size() && j < labelSegments_.size()) { - if (outputSegments_[i] == labelSegments_[j]) { + if (outputSegments_[i] == labelSegments_[j] && + excludedChunkTypes_.count(outputSegments_[i].type) != 1) { ++numCorrect_; } if (outputSegments_[i].end < labelSegments_[j].end) { @@ -168,8 +175,12 @@ public: ++j; } } - numLabelSegments_ += labelSegments_.size(); - numOutputSegments_ += outputSegments_.size(); + for (auto& segment : labelSegments_) { + if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_; + } + for (auto& segment : outputSegments_) { + if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_; + } } void getSegments(int* label, int length, std::vector& segments) { diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 552af71e76e5adf27f35bb5ad6fd8a69c71df0f1..be4d0041f91cf7d0306d14338b43bb25e052fd58 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -433,8 +433,10 @@ message EvaluatorConfig { repeated string input_layers = 3; // Used by ChunkEvaluator - optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES" - optional int32 num_chunk_types = 5; // number of chunk types other than "other" + // one of "IOB", "IOE", "IOBES" + optional string chunk_scheme = 4; + // number of chunk types other than "other" + optional int32 num_chunk_types = 5; // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator // For multi binary labels: true if output > classification_threshold @@ -453,6 +455,10 @@ message EvaluatorConfig { // whether to delimit the sequence in the seq_text_printer optional bool delimited = 11 [default = true]; + + // Used by ChunkEvaluator + // chunk of these types are not counted + repeated int32 excluded_chunk_types = 12; } message LinkConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index ea3e4308fe05be464c3e8c6b84d8b7be8a30c016..39892d0533aab468d808274146ae1a0f72170495 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1240,7 +1240,8 @@ def Evaluator( dict_file=None, result_file=None, num_results=None, - delimited=None, ): + delimited=None, + excluded_chunk_types=None, ): evaluator = g_config.model_config.evaluators.add() evaluator.type = type evaluator.name = MakeLayerNameInSubmodel(name) @@ -1269,6 +1270,9 @@ def Evaluator( if delimited is not None: evaluator.delimited = delimited + if excluded_chunk_types: + evaluator.excluded_chunk_types.extend(excluded_chunk_types) + class LayerBase(object): def __init__( diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index 3e0e88972c58e8c853e79e21f839943ae4b027d6..bd247ea9af9d8dfb2d476cdc62638bd65c11add5 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -57,19 +57,21 @@ def evaluator(*attrs): return impl -def evaluator_base(input, - type, - label=None, - weight=None, - name=None, - chunk_scheme=None, - num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - delimited=None): +def evaluator_base( + input, + type, + label=None, + weight=None, + name=None, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + delimited=None, + excluded_chunk_types=None, ): """ Evaluator will evaluate the network status while training/testing. @@ -127,7 +129,8 @@ def evaluator_base(input, positive_label=positive_label, dict_file=dict_file, result_file=result_file, - delimited=delimited) + delimited=delimited, + excluded_chunk_types=excluded_chunk_types, ) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) @@ -330,7 +333,8 @@ def chunk_evaluator( label, chunk_scheme, num_chunk_types, - name=None, ): + name=None, + excluded_chunk_types=None, ): """ Chunk evaluator is used to evaluate segment labelling accuracy for a sequence. It calculates the chunk detection F1 score. @@ -376,6 +380,8 @@ def chunk_evaluator( :param num_chunk_types: number of chunk types other than "other" :param name: The Evaluator name, it is optional. :type name: basename|None + :param excluded_chunk_types: chunks of these types are not considered + :type excluded_chunk_types: list of integer|None """ evaluator_base( name=name, @@ -383,7 +389,8 @@ def chunk_evaluator( input=input, label=label, chunk_scheme=chunk_scheme, - num_chunk_types=num_chunk_types) + num_chunk_types=num_chunk_types, + excluded_chunk_types=excluded_chunk_types, ) @evaluator(EvaluatorAttribute.FOR_UTILS)