Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8cdb42c2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8cdb42c2
编写于
11月 02, 2017
作者:
Y
Yang yaming
提交者:
GitHub
11月 02, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5111 from pkuyym/fix-5070
Add PrecisionRecall Op
上级
69011c18
970613fc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
513 addition
and
0 deletion
+513
-0
paddle/operators/precision_recall_op.cc
paddle/operators/precision_recall_op.cc
+179
-0
paddle/operators/precision_recall_op.h
paddle/operators/precision_recall_op.h
+161
-0
python/paddle/v2/framework/tests/test_precision_recall_op.py
python/paddle/v2/framework/tests/test_precision_recall_op.py
+173
-0
未找到文件。
paddle/operators/precision_recall_op.cc
0 → 100644
浏览文件 @
8cdb42c2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/precision_recall_op.h"
namespace
paddle
{
namespace
operators
{
class
PrecisionRecallOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MaxProbs"
),
"Input(MaxProbs) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Indices"
),
"Input(Indices) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchMetrics"
),
"Output(BatchMetrics) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"AccumMetrics"
),
"Output(AccumMetrics) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"AccumStatesInfo"
),
"Output(AccumStatesInfo) should not be null."
);
int64_t
cls_num
=
static_cast
<
int64_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"class_number"
));
auto
max_probs_dims
=
ctx
->
GetInputDim
(
"MaxProbs"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Labels"
);
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
1
],
1
,
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1]."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Indices"
),
max_probs_dims
,
"The shape of Input(Indices) should be [batch_size, 1]."
);
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
0
],
labels_dims
[
0
],
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"be the same."
);
PADDLE_ENFORCE_EQ
(
labels_dims
[
1
],
1
,
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1."
);
if
(
ctx
->
HasInput
(
"Weights"
))
{
auto
weights_dims
=
ctx
->
GetInputDim
(
"Weights"
);
PADDLE_ENFORCE_EQ
(
weights_dims
,
framework
::
make_ddim
({
max_probs_dims
[
0
],
1
}),
"The shape of Input(Weights) should be "
"[batch_size, 1]."
);
}
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
cls_num
,
4
}),
"The shape of Input(StatesInfo) should be "
"[class_number, 4]."
);
}
// Layouts of BatchMetrics and AccumMetrics both are:
// [
// macro average precision, macro average recall, macro average F1 score,
// micro average precision, micro average recall, micro average F1 score
// ]
ctx
->
SetOutputDim
(
"BatchMetrics"
,
{
6
});
ctx
->
SetOutputDim
(
"AccumMetrics"
,
{
6
});
// Shape of AccumStatesInfo is [class_number, 4]
// The layout of each row is:
// [ TP, FP, TN, FN ]
ctx
->
SetOutputDim
(
"AccumStatesInfo"
,
{
cls_num
,
4
});
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
());
}
};
class
PrecisionRecallOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
PrecisionRecallOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"MaxProbs"
,
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the max probability "
"of an instance which computed by the previous top_k (k=1) "
"operator."
);
AddInput
(
"Indices"
,
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the corresponding "
"index which computed by the previous top_k (k=1) operator."
);
AddInput
(
"Labels"
,
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each element is a label and the "
"value should be in [0, class_number - 1]."
);
AddInput
(
"Weights"
,
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
"where N is the batch size. This input is optional. If provided, "
"weight of instance would be considered when computing metrics."
)
.
AsDispensable
();
AddInput
(
"StatesInfo"
,
"(Tensor, default Tensor<int>), a 2-D tensor with shape D x 4, "
"where D is the number of classes. This input is optional. If "
"provided, current state will be accumulated to this state and "
"the accumulation state will be as the output state."
)
.
AsDispensable
();
AddOutput
(
"BatchMetrics"
,
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
"This output tensor contains metrics for current batch data."
"The layout is [macro average precision, macro average recall, "
"macro f1 score, micro average precision, micro average recall, "
"micro f1 score]"
);
AddOutput
(
"AccumMetrics"
,
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
"This output tensor contains metrics for accumulated data."
"The layout is [macro average precision, macro average recall, "
"macro f1 score, micro average precision, micro average recall, "
"micro f1 score]"
);
AddOutput
(
"AccumStatesInfo"
,
"(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
"where D is equal to class number. This output tensor contains "
"accumulated state variables used to compute metrics. The layout "
"for each class is [true positives, false positives, "
"true negatives, false negatives]."
);
AddAttr
<
int
>
(
"class_number"
,
"Number of classes to be evaluated."
);
AddComment
(
R"DOC(
When given 'Input(Indices)' and 'Input(Labels)', this operator can be used
to compute various metrics including:
- macro average precision
- macro average recall
- macro f1 score
- micro average precision
- micro average recall
- micro f1 score
To compute the above metrics, we need to do statistics for true positives,
false positives and false negatives. Here count of true negatives is not
necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class_number, 4]. Each row of a
state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), TN(true negatives),
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
calculated by given weight instead of instance count.
This operator also supports metrics computing for cross-batch situation. To
achieve this, 'Input(StatesInfo)' should be provided. State of current batch
data will be accumulated to 'Input(StatesInfo)' and 'Output(AccumStatesInfo)'
is the accumulation state.
'Output(BatchMetrics)' is metrics of current batch data while
'Output(AccumStatesInfo)' is metrics of accumulation data.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
precision_recall
,
ops
::
PrecisionRecallOp
,
ops
::
PrecisionRecallOpMaker
);
REGISTER_OP_CPU_KERNEL
(
precision_recall
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/precision_recall_op.h
0 → 100644
浏览文件 @
8cdb42c2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
enum
StateVariable
{
TP
=
0
,
FP
,
TN
,
FN
};
template
<
typename
Place
,
typename
T
>
class
PrecisionRecallKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in0
=
ctx
.
Input
<
Tensor
>
(
"Indices"
);
auto
*
in1
=
ctx
.
Input
<
Tensor
>
(
"Labels"
);
auto
*
in2
=
ctx
.
Input
<
Tensor
>
(
"Weights"
);
auto
*
in3
=
ctx
.
Input
<
Tensor
>
(
"StatesInfo"
);
auto
*
out0
=
ctx
.
Output
<
Tensor
>
(
"BatchMetrics"
);
auto
*
out1
=
ctx
.
Output
<
Tensor
>
(
"AccumMetrics"
);
auto
*
out2
=
ctx
.
Output
<
Tensor
>
(
"AccumStatesInfo"
);
const
int
*
ids_data
=
in0
->
data
<
int
>
();
const
int
*
labels_data
=
in1
->
data
<
int
>
();
size_t
cls_num
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"class_number"
));
const
T
*
weights_data
=
in2
?
in2
->
data
<
T
>
()
:
nullptr
;
const
T
*
states_data
=
in3
?
in3
->
data
<
T
>
()
:
nullptr
;
double
*
batch_metrics_data
=
out0
->
mutable_data
<
double
>
(
ctx
.
GetPlace
());
double
*
accum_metrics_data
=
out1
->
mutable_data
<
double
>
(
ctx
.
GetPlace
());
out2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
accum_states
=
EigenMatrix
<
T
>::
From
(
*
out2
);
accum_states
.
setZero
();
T
*
accum_states_data
=
out2
->
data
<
T
>
();
size_t
sample_num
=
in0
->
dims
()[
0
];
size_t
state_var_num
=
4
;
// TP FP TN FN
// get states info for current batch
for
(
size_t
i
=
0
;
i
<
sample_num
;
++
i
)
{
size_t
idx
=
ids_data
[
i
];
size_t
label
=
labels_data
[
i
];
PADDLE_ENFORCE
(
idx
>=
0
&&
idx
<
cls_num
,
"Class index of each instance should be in "
"[0, class_number)."
);
PADDLE_ENFORCE
(
label
>=
0
&&
label
<
cls_num
,
"Label of each instance should be in [0, class_number)."
);
T
w
=
weights_data
?
weights_data
[
i
]
:
1.0
;
if
(
idx
==
label
)
{
accum_states_data
[
idx
*
state_var_num
+
TP
]
+=
w
;
for
(
size_t
j
=
0
;
j
<
cls_num
;
++
j
)
{
accum_states_data
[
j
*
state_var_num
+
TN
]
+=
w
;
}
accum_states_data
[
idx
*
state_var_num
+
TN
]
-=
w
;
}
else
{
accum_states_data
[
label
*
state_var_num
+
FN
]
+=
w
;
accum_states_data
[
idx
*
state_var_num
+
FP
]
+=
w
;
for
(
size_t
j
=
0
;
j
<
cls_num
;
++
j
)
{
accum_states_data
[
j
*
state_var_num
+
TN
]
+=
w
;
}
accum_states_data
[
idx
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
label
*
state_var_num
+
TN
]
-=
w
;
}
}
ComputeMetrics
(
accum_states_data
,
batch_metrics_data
,
state_var_num
,
cls_num
);
if
(
states_data
)
{
for
(
size_t
i
=
0
;
i
<
cls_num
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
state_var_num
;
++
j
)
{
size_t
idx
=
i
*
state_var_num
+
j
;
accum_states_data
[
idx
]
+=
states_data
[
idx
];
}
}
}
ComputeMetrics
(
accum_states_data
,
accum_metrics_data
,
state_var_num
,
cls_num
);
}
// expose to be reused
static
inline
T
CalcPrecision
(
T
tp_count
,
T
fp_count
)
{
if
(
tp_count
>
0.0
||
fp_count
>
0.0
)
{
return
tp_count
/
(
tp_count
+
fp_count
);
}
return
1.0
;
}
static
inline
T
CalcRecall
(
T
tp_count
,
T
fn_count
)
{
if
(
tp_count
>
0.0
||
fn_count
>
0.0
)
{
return
tp_count
/
(
tp_count
+
fn_count
);
}
return
1.0
;
}
static
inline
T
CalcF1Score
(
T
precision
,
T
recall
)
{
if
(
precision
>
0.0
||
recall
>
0.0
)
{
return
2
*
precision
*
recall
/
(
precision
+
recall
);
}
return
0.0
;
}
protected:
void
ComputeMetrics
(
const
T
*
states_data
,
double
*
metrics_data
,
size_t
state_var_num
,
size_t
cls_num
)
const
{
T
total_tp_count
=
0
;
T
total_fp_count
=
0
;
T
total_fn_count
=
0
;
T
macro_avg_precision
=
0.0
;
T
macro_avg_recall
=
0.0
;
for
(
size_t
i
=
0
;
i
<
cls_num
;
++
i
)
{
T
tp_count
=
states_data
[
i
*
state_var_num
+
TP
];
T
fp_count
=
states_data
[
i
*
state_var_num
+
FP
];
T
fn_count
=
states_data
[
i
*
state_var_num
+
FN
];
total_tp_count
+=
tp_count
;
total_fp_count
+=
fp_count
;
total_fn_count
+=
fn_count
;
macro_avg_precision
+=
CalcPrecision
(
tp_count
,
fp_count
);
macro_avg_recall
+=
CalcRecall
(
tp_count
,
fn_count
);
}
macro_avg_precision
/=
cls_num
;
macro_avg_recall
/=
cls_num
;
T
macro_f1_score
=
CalcF1Score
(
macro_avg_precision
,
macro_avg_recall
);
T
micro_avg_precision
=
CalcPrecision
(
total_tp_count
,
total_fp_count
);
T
micro_avg_recall
=
CalcRecall
(
total_tp_count
,
total_fn_count
);
T
micro_f1_score
=
CalcF1Score
(
micro_avg_precision
,
micro_avg_recall
);
// fill metrics data
metrics_data
[
0
]
=
macro_avg_precision
;
metrics_data
[
1
]
=
macro_avg_recall
;
metrics_data
[
2
]
=
macro_f1_score
;
metrics_data
[
3
]
=
micro_avg_precision
;
metrics_data
[
4
]
=
micro_avg_recall
;
metrics_data
[
5
]
=
micro_f1_score
;
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/v2/framework/tests/test_precision_recall_op.py
0 → 100644
浏览文件 @
8cdb42c2
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
calc_precision
(
tp_count
,
fp_count
):
if
tp_count
>
0.0
or
fp_count
>
0.0
:
return
tp_count
/
(
tp_count
+
fp_count
)
return
1.0
def
calc_recall
(
tp_count
,
fn_count
):
if
tp_count
>
0.0
or
fn_count
>
0.0
:
return
tp_count
/
(
tp_count
+
fn_count
)
return
1.0
def
calc_f1_score
(
precision
,
recall
):
if
precision
>
0.0
or
recall
>
0.0
:
return
2
*
precision
*
recall
/
(
precision
+
recall
)
return
0.0
def
get_states
(
idxs
,
labels
,
cls_num
,
weights
=
None
):
ins_num
=
idxs
.
shape
[
0
]
# TP FP TN FN
states
=
np
.
zeros
((
cls_num
,
4
)).
astype
(
'float32'
)
for
i
in
xrange
(
ins_num
):
w
=
weights
[
i
]
if
weights
is
not
None
else
1.0
idx
=
idxs
[
i
][
0
]
label
=
labels
[
i
][
0
]
if
idx
==
label
:
states
[
idx
][
0
]
+=
w
for
j
in
xrange
(
cls_num
):
states
[
j
][
2
]
+=
w
states
[
idx
][
2
]
-=
w
else
:
states
[
label
][
3
]
+=
w
states
[
idx
][
1
]
+=
w
for
j
in
xrange
(
cls_num
):
states
[
j
][
2
]
+=
w
states
[
label
][
2
]
-=
w
states
[
idx
][
2
]
-=
w
return
states
def
compute_metrics
(
states
,
cls_num
):
total_tp_count
=
0.0
total_fp_count
=
0.0
total_fn_count
=
0.0
macro_avg_precision
=
0.0
macro_avg_recall
=
0.0
for
i
in
xrange
(
cls_num
):
total_tp_count
+=
states
[
i
][
0
]
total_fp_count
+=
states
[
i
][
1
]
total_fn_count
+=
states
[
i
][
3
]
macro_avg_precision
+=
calc_precision
(
states
[
i
][
0
],
states
[
i
][
1
])
macro_avg_recall
+=
calc_recall
(
states
[
i
][
0
],
states
[
i
][
3
])
metrics
=
[]
macro_avg_precision
/=
cls_num
macro_avg_recall
/=
cls_num
metrics
.
append
(
macro_avg_precision
)
metrics
.
append
(
macro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
macro_avg_precision
,
macro_avg_recall
))
micro_avg_precision
=
calc_precision
(
total_tp_count
,
total_fp_count
)
metrics
.
append
(
micro_avg_precision
)
micro_avg_recall
=
calc_recall
(
total_tp_count
,
total_fn_count
)
metrics
.
append
(
micro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
micro_avg_precision
,
micro_avg_recall
))
return
np
.
array
(
metrics
).
astype
(
'float32'
)
class
TestPrecisionRecallOp_0
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
cls_num
=
10
max_probs
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
idxs
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
labels
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
idxs
,
labels
,
cls_num
)
metrics
=
compute_metrics
(
states
,
cls_num
)
self
.
attrs
=
{
'class_number'
:
cls_num
}
self
.
inputs
=
{
'MaxProbs'
:
max_probs
,
'Indices'
:
idxs
,
'Labels'
:
labels
}
self
.
outputs
=
{
'BatchMetrics'
:
metrics
,
'AccumMetrics'
:
metrics
,
'AccumStatesInfo'
:
states
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestPrecisionRecallOp_1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
cls_num
=
10
max_probs
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
idxs
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
weights
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
idxs
,
labels
,
cls_num
,
weights
)
metrics
=
compute_metrics
(
states
,
cls_num
)
self
.
attrs
=
{
'class_number'
:
cls_num
}
self
.
inputs
=
{
'MaxProbs'
:
max_probs
,
'Indices'
:
idxs
,
'Labels'
:
labels
,
'Weights'
:
weights
}
self
.
outputs
=
{
'BatchMetrics'
:
metrics
,
'AccumMetrics'
:
metrics
,
'AccumStatesInfo'
:
states
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestPrecisionRecallOp_2
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
cls_num
=
10
max_probs
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
idxs
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
weights
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
np
.
random
.
randint
(
0
,
30
,
(
cls_num
,
4
)).
astype
(
'float32'
)
accum_states
=
get_states
(
idxs
,
labels
,
cls_num
,
weights
)
batch_metrics
=
compute_metrics
(
accum_states
,
cls_num
)
accum_states
+=
states
accum_metrics
=
compute_metrics
(
accum_states
,
cls_num
)
self
.
attrs
=
{
'class_number'
:
cls_num
}
self
.
inputs
=
{
'MaxProbs'
:
max_probs
,
'Indices'
:
idxs
,
'Labels'
:
labels
,
'Weights'
:
weights
,
'StatesInfo'
:
states
}
self
.
outputs
=
{
'BatchMetrics'
:
batch_metrics
,
'AccumMetrics'
:
accum_metrics
,
'AccumStatesInfo'
:
accum_states
}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录