提交 8fe3a3aa 编写于 作者: P Peng Li

Add excluded_chunk_types to ChunkEvaluator

The chunks of types in excluded_chunk_types will not be counted in
ChunkEvaluator. This is useful for tasks such as SRL, in which chunks of
type V (verb) will not be taken into account in evaluation.
上级 8a42a549
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <set>
#include <vector>
#include "paddle/math/Vector.h"
......@@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
std::vector<Segment> labelSegments_;
std::vector<Segment> outputSegments_;
std::set<int> excludedChunkTypes_;
public:
virtual void init(const EvaluatorConfig& config) {
......@@ -105,6 +107,10 @@ public:
}
CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config";
otherChunkType_ = numChunkTypes_ = config.num_chunk_types();
// the chunks of types in excludedChunkTypes_ will not be counted
auto& tmp = config.excluded_chunk_types();
excludedChunkTypes_.insert(tmp.begin(), tmp.end());
}
virtual void start() {
......@@ -157,6 +163,7 @@ public:
size_t i = 0, j = 0;
while (i < outputSegments_.size() && j < labelSegments_.size()) {
if (outputSegments_[i] == labelSegments_[j]) {
if (excludedChunkTypes_.count(outputSegments_[i].type) != 1)
++numCorrect_;
}
if (outputSegments_[i].end < labelSegments_[j].end) {
......@@ -168,8 +175,12 @@ public:
++j;
}
}
numLabelSegments_ += labelSegments_.size();
numOutputSegments_ += outputSegments_.size();
for (auto& segment : labelSegments_) {
if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_;
}
for (auto& segment : outputSegments_) {
if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_;
}
}
void getSegments(int* label, int length, std::vector<Segment>& segments) {
......
......@@ -433,8 +433,12 @@ message EvaluatorConfig {
repeated string input_layers = 3;
// Used by ChunkEvaluator
optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES"
optional int32 num_chunk_types = 5; // number of chunk types other than "other"
// one of "IOB", "IOE", "IOBES"
optional string chunk_scheme = 4;
// number of chunk types other than "other"
optional int32 num_chunk_types = 5;
// chunk of these types are not counted
repeated int32 excluded_chunk_types = 12;
// Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator
// For multi binary labels: true if output > classification_threshold
......@@ -453,6 +457,8 @@ message EvaluatorConfig {
// whether to delimit the sequence in the seq_text_printer
optional bool delimited = 11 [default = true];
// NOTE: 12 has been occupied by excluded_chunk_types
}
message LinkConfig {
......
......@@ -1240,7 +1240,8 @@ def Evaluator(
dict_file=None,
result_file=None,
num_results=None,
delimited=None, ):
delimited=None,
excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add()
evaluator.type = type
evaluator.name = MakeLayerNameInSubmodel(name)
......@@ -1269,6 +1270,9 @@ def Evaluator(
if delimited is not None:
evaluator.delimited = delimited
if excluded_chunk_types:
evaluator.excluded_chunk_types.extend(excluded_chunk_types)
class LayerBase(object):
def __init__(
......
......@@ -57,7 +57,8 @@ def evaluator(*attrs):
return impl
def evaluator_base(input,
def evaluator_base(
input,
type,
label=None,
weight=None,
......@@ -69,7 +70,8 @@ def evaluator_base(input,
dict_file=None,
result_file=None,
num_results=None,
delimited=None):
delimited=None,
excluded_chunk_types=None, ):
"""
Evaluator will evaluate the network status while training/testing.
......@@ -127,7 +129,8 @@ def evaluator_base(input,
positive_label=positive_label,
dict_file=dict_file,
result_file=result_file,
delimited=delimited)
delimited=delimited,
excluded_chunk_types=excluded_chunk_types, )
@evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
......@@ -330,7 +333,8 @@ def chunk_evaluator(
label,
chunk_scheme,
num_chunk_types,
name=None, ):
name=None,
excluded_chunk_types=None, ):
"""
Chunk evaluator is used to evaluate segment labelling accuracy for a
sequence. It calculates the chunk detection F1 score.
......@@ -376,6 +380,8 @@ def chunk_evaluator(
:param num_chunk_types: number of chunk types other than "other"
:param name: The Evaluator name, it is optional.
:type name: basename|None
:param excluded_chunk_types: chunks of these types are not considered
:type excluded_chunk_types: list of integer|[]
"""
evaluator_base(
name=name,
......@@ -383,7 +389,8 @@ def chunk_evaluator(
input=input,
label=label,
chunk_scheme=chunk_scheme,
num_chunk_types=num_chunk_types)
num_chunk_types=num_chunk_types,
excluded_chunk_types=excluded_chunk_types, )
@evaluator(EvaluatorAttribute.FOR_UTILS)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册