提交 d2b10cc0 编写于 作者: Y yangyaming

Refine doc and fix data type of metrics.

上级 97bfc0df
......@@ -136,9 +136,9 @@ to compute various metrics including:
- micro average recall
- micro f1 score
To compute the above metrics, we need to statistic counts for true positives,
To compute the above metrics, we need to do statistics for true positives,
false positives and false negatives. Here count of true negatives is not
necessary, but statisticing it may provide potential usage and the cost is
necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class number, 4]. Each row of a
......
......@@ -42,8 +42,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
const int* labels_data = in1->data<int>();
const T* weights_data = in2 ? in2->data<T>() : nullptr;
const T* states_data = in3 ? in3->data<T>() : nullptr;
T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
T* accum_metrics_data = out1->mutable_data<T>(ctx.GetPlace());
double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace());
double* accum_metrics_data = out1->mutable_data<double>(ctx.GetPlace());
out2->mutable_data<T>(ctx.GetPlace());
auto accum_states = EigenMatrix<T>::From(*out2);
accum_states.setZero();
......@@ -121,7 +121,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
}
protected:
void ComputeMetrics(const T* states_data, T* metrics_data,
void ComputeMetrics(const T* states_data, double* metrics_data,
size_t state_var_num, size_t class_dim) const {
T total_tp_count = 0;
T total_fp_count = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册