未验证 提交 d08791d1 编写于 作者: A Abhinav Arora 提交者: GitHub

Fix CPPLint issues with Chunk_eval_op (#9964)

上级 8352f938
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/chunk_eval_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......
......@@ -14,6 +14,9 @@ limitations under the License. */
#pragma once
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -36,11 +39,11 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
};
void GetSegments(const int64_t* label, int length,
std::vector<Segment>& segments, int num_chunk_types,
std::vector<Segment>* segments, int num_chunk_types,
int num_tag_types, int other_chunk_type, int tag_begin,
int tag_inside, int tag_end, int tag_single) const {
segments.clear();
segments.reserve(length);
segments->clear();
segments->reserve(length);
int chunk_start = 0;
bool in_chunk = false;
int tag = -1;
......@@ -58,7 +61,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
i - 1, // end
prev_type,
};
segments.push_back(segment);
segments->push_back(segment);
in_chunk = false;
}
if (ChunkBegin(prev_tag, prev_type, tag, type, other_chunk_type,
......@@ -73,7 +76,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
length - 1, // end
type,
};
segments.push_back(segment);
segments->push_back(segment);
}
}
......@@ -177,8 +180,8 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
for (int i = 0; i < num_sequences; ++i) {
int seq_length = lod[0][i + 1] - lod[0][i];
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
output_segments, label_segments, *num_infer_chunks_data,
*num_label_chunks_data, *num_correct_chunks_data,
&output_segments, &label_segments, num_infer_chunks_data,
num_label_chunks_data, num_correct_chunks_data,
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
tag_inside, tag_end, tag_single, excluded_chunk_types);
}
......@@ -197,10 +200,10 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
}
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
std::vector<Segment>& output_segments,
std::vector<Segment>& label_segments,
int64_t& num_output_segments, int64_t& num_label_segments,
int64_t& num_correct, int num_chunk_types, int num_tag_types,
std::vector<Segment>* output_segments,
std::vector<Segment>* label_segments,
int64_t* num_output_segments, int64_t* num_label_segments,
int64_t* num_correct, int num_chunk_types, int num_tag_types,
int other_chunk_type, int tag_begin, int tag_inside,
int tag_end, int tag_single,
const std::set<int>& excluded_chunk_types) const {
......@@ -209,25 +212,29 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
GetSegments(label, length, label_segments, num_chunk_types, num_tag_types,
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
size_t i = 0, j = 0;
while (i < output_segments.size() && j < label_segments.size()) {
if (output_segments[i] == label_segments[j] &&
excluded_chunk_types.count(output_segments[i].type) != 1) {
++num_correct;
while (i < output_segments->size() && j < label_segments->size()) {
if (output_segments->at(i) == label_segments->at(j) &&
excluded_chunk_types.count(output_segments->at(i).type) != 1) {
++(*num_correct);
}
if (output_segments[i].end < label_segments[j].end) {
if (output_segments->at(i).end < label_segments->at(j).end) {
++i;
} else if (output_segments[i].end > label_segments[j].end) {
} else if (output_segments->at(i).end > label_segments->at(j).end) {
++j;
} else {
++i;
++j;
}
}
for (auto& segment : label_segments) {
if (excluded_chunk_types.count(segment.type) != 1) ++num_label_segments;
for (auto& segment : (*label_segments)) {
if (excluded_chunk_types.count(segment.type) != 1) {
++(*num_label_segments);
}
}
for (auto& segment : output_segments) {
if (excluded_chunk_types.count(segment.type) != 1) ++num_output_segments;
for (auto& segment : (*output_segments)) {
if (excluded_chunk_types.count(segment.type) != 1) {
++(*num_output_segments);
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册