Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
970613fc
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
970613fc
编写于
11月 01, 2017
作者:
Y
yangyaming
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine and follow comments.
上级
d2b10cc0
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
115 addition
and
98 deletion
+115
-98
paddle/operators/precision_recall_op.cc
paddle/operators/precision_recall_op.cc
+35
-27
paddle/operators/precision_recall_op.h
paddle/operators/precision_recall_op.h
+27
-27
python/paddle/v2/framework/tests/test_precision_recall_op.py
python/paddle/v2/framework/tests/test_precision_recall_op.py
+53
-44
未找到文件。
paddle/operators/precision_recall_op.cc
浏览文件 @
970613fc
...
...
@@ -22,8 +22,10 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Predictions"
),
"Input(Predictions) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MaxProbs"
),
"Input(MaxProbs) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Indices"
),
"Input(Indices) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchMetrics"
),
...
...
@@ -33,34 +35,36 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"AccumStatesInfo"
),
"Output(AccumStatesInfo) should not be null."
);
auto
predictions_dims
=
ctx
->
GetInputDim
(
"Predictions"
);
int64_t
cls_num
=
static_cast
<
int64_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"class_number"
));
auto
max_probs_dims
=
ctx
->
GetInputDim
(
"MaxProbs"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Labels"
);
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
1
],
1
,
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1]."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Indices"
),
max_probs_dims
,
"The shape of Input(Indices) should be [batch_size, 1]."
);
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
0
],
labels_dims
[
0
],
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"be the same."
);
PADDLE_ENFORCE_EQ
(
labels_dims
[
1
],
1
,
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1."
);
if
(
ctx
->
HasInput
(
"Weights"
))
{
auto
weights_dims
=
ctx
->
GetInputDim
(
"Weights"
);
PADDLE_ENFORCE_EQ
(
weights_dims
,
framework
::
make_ddim
({
prediction
s_dims
[
0
],
1
}),
framework
::
make_ddim
({
max_prob
s_dims
[
0
],
1
}),
"The shape of Input(Weights) should be "
"[batch_size, 1]."
);
}
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
predictions_dims
[
1
],
4
}),
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
cls_num
,
4
}),
"The shape of Input(StatesInfo) should be "
"[class_number, 4]."
);
}
PADDLE_ENFORCE_EQ
(
predictions_dims
[
0
],
labels_dims
[
0
],
"The 1st dimension of Input(Predictions) and "
"Input(Labels) both are batch_size and the shape should "
"be the same."
);
PADDLE_ENFORCE_EQ
(
labels_dims
[
1
],
1
,
"The 2nd dimension of Input(Labels) "
"contains instance label and the shape should be equal "
"to 1"
);
PADDLE_ENFORCE_GE
(
predictions_dims
[
1
],
1
,
"The shape of Input(Predictions)'s 2nd dimension is "
"equal to class number and should be at least 1."
);
// Layouts of BatchMetrics and AccumMetrics both are:
// [
...
...
@@ -72,13 +76,13 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
// Shape of AccumStatesInfo is [class_number, 4]
// The layout of each row is:
// [ TP, FP, TN, FN ]
ctx
->
SetOutputDim
(
"AccumStatesInfo"
,
{
predictions_dims
[
1
]
,
4
});
ctx
->
SetOutputDim
(
"AccumStatesInfo"
,
{
cls_num
,
4
});
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"
Prediction
s"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"
MaxProb
s"
)
->
type
());
}
};
...
...
@@ -87,11 +91,15 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
PrecisionRecallOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Predictions"
,
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
"Each row contains probabilities for an instance which computed "
"by the previous operator."
);
AddInput
(
"MaxProbs"
,
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the max probability "
"of an instance which computed by the previous top_k (k=1) "
"operator."
);
AddInput
(
"Indices"
,
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the corresponding "
"index which computed by the previous top_k (k=1) operator."
);
AddInput
(
"Labels"
,
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each element is a label and the "
...
...
@@ -125,9 +133,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
"accumulated state variables used to compute metrics. The layout "
"for each class is [true positives, false positives, "
"true negatives, false negatives]."
);
AddAttr
<
int
>
(
"class_number"
,
"Number of classes to be evaluated."
);
AddComment
(
R"DOC(
When given 'Input(
Prediction
s)' and 'Input(Labels)', this operator can be used
When given 'Input(
Indice
s)' and 'Input(Labels)', this operator can be used
to compute various metrics including:
- macro average precision
- macro average recall
...
...
@@ -141,7 +149,7 @@ false positives and false negatives. Here count of true negatives is not
necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class
number, 4]. Each row of a
We define state as a 2-D tensor with shape [class
_
number, 4]. Each row of a
state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), TN(true negatives),
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
...
...
paddle/operators/precision_recall_op.h
浏览文件 @
970613fc
...
...
@@ -30,7 +30,7 @@ template <typename Place, typename T>
class
PrecisionRecallKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in0
=
ctx
.
Input
<
Tensor
>
(
"
Prediction
s"
);
auto
*
in0
=
ctx
.
Input
<
Tensor
>
(
"
Indice
s"
);
auto
*
in1
=
ctx
.
Input
<
Tensor
>
(
"Labels"
);
auto
*
in2
=
ctx
.
Input
<
Tensor
>
(
"Weights"
);
auto
*
in3
=
ctx
.
Input
<
Tensor
>
(
"StatesInfo"
);
...
...
@@ -38,8 +38,9 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
auto
*
out1
=
ctx
.
Output
<
Tensor
>
(
"AccumMetrics"
);
auto
*
out2
=
ctx
.
Output
<
Tensor
>
(
"AccumStatesInfo"
);
const
T
*
predictions_data
=
in0
->
data
<
T
>
();
const
int
*
ids_data
=
in0
->
data
<
int
>
();
const
int
*
labels_data
=
in1
->
data
<
int
>
();
size_t
cls_num
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"class_number"
));
const
T
*
weights_data
=
in2
?
in2
->
data
<
T
>
()
:
nullptr
;
const
T
*
states_data
=
in3
?
in3
->
data
<
T
>
()
:
nullptr
;
double
*
batch_metrics_data
=
out0
->
mutable_data
<
double
>
(
ctx
.
GetPlace
());
...
...
@@ -50,43 +51,42 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
T
*
accum_states_data
=
out2
->
data
<
T
>
();
size_t
sample_num
=
in0
->
dims
()[
0
];
size_t
class_dim
=
in0
->
dims
()[
1
];
size_t
state_var_num
=
4
;
// TP FP TN FN
// get states info for current batch
for
(
size_t
i
=
0
;
i
<
sample_num
;
++
i
)
{
size_t
max_idx
=
0
;
T
max_val
=
predictions_data
[
i
*
class_dim
];
for
(
size_t
j
=
1
;
j
<
class_dim
;
++
j
)
{
if
(
max_val
<
predictions_data
[
i
*
class_dim
+
j
])
{
max_idx
=
j
;
max_val
=
predictions_data
[
i
*
class_dim
+
j
]
;
}
}
size_t
idx
=
ids_data
[
i
]
;
size_t
label
=
labels_data
[
i
];
PADDLE_ENFORCE
(
idx
>=
0
&&
idx
<
cls_num
,
"Class index of each instance should be in "
"[0, class_number)."
)
;
PADDLE_ENFORCE
(
label
>=
0
&&
label
<
cls_num
,
"Label of each instance should be in [0, class_number)."
);
T
w
=
weights_data
?
weights_data
[
i
]
:
1.0
;
if
(
max_idx
==
labels_data
[
i
]
)
{
accum_states_data
[
max_
idx
*
state_var_num
+
TP
]
+=
w
;
for
(
size_t
j
=
0
;
j
<
cl
ass_di
m
;
++
j
)
{
if
(
idx
==
label
)
{
accum_states_data
[
idx
*
state_var_num
+
TP
]
+=
w
;
for
(
size_t
j
=
0
;
j
<
cl
s_nu
m
;
++
j
)
{
accum_states_data
[
j
*
state_var_num
+
TN
]
+=
w
;
}
accum_states_data
[
max_
idx
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
idx
*
state_var_num
+
TN
]
-=
w
;
}
else
{
accum_states_data
[
label
s_data
[
i
]
*
state_var_num
+
FN
]
+=
w
;
accum_states_data
[
max_
idx
*
state_var_num
+
FP
]
+=
w
;
for
(
size_t
j
=
0
;
j
<
cl
ass_di
m
;
++
j
)
{
accum_states_data
[
label
*
state_var_num
+
FN
]
+=
w
;
accum_states_data
[
idx
*
state_var_num
+
FP
]
+=
w
;
for
(
size_t
j
=
0
;
j
<
cl
s_nu
m
;
++
j
)
{
accum_states_data
[
j
*
state_var_num
+
TN
]
+=
w
;
}
accum_states_data
[
max_
idx
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
label
s_data
[
i
]
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
idx
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
label
*
state_var_num
+
TN
]
-=
w
;
}
}
ComputeMetrics
(
accum_states_data
,
batch_metrics_data
,
state_var_num
,
cl
ass_di
m
);
cl
s_nu
m
);
if
(
states_data
)
{
for
(
size_t
i
=
0
;
i
<
cl
ass_di
m
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
cl
s_nu
m
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
state_var_num
;
++
j
)
{
size_t
idx
=
i
*
state_var_num
+
j
;
accum_states_data
[
idx
]
+=
states_data
[
idx
];
...
...
@@ -95,7 +95,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
}
ComputeMetrics
(
accum_states_data
,
accum_metrics_data
,
state_var_num
,
cl
ass_di
m
);
cl
s_nu
m
);
}
// expose to be reused
...
...
@@ -122,14 +122,14 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
protected:
void
ComputeMetrics
(
const
T
*
states_data
,
double
*
metrics_data
,
size_t
state_var_num
,
size_t
cl
ass_di
m
)
const
{
size_t
state_var_num
,
size_t
cl
s_nu
m
)
const
{
T
total_tp_count
=
0
;
T
total_fp_count
=
0
;
T
total_fn_count
=
0
;
T
macro_avg_precision
=
0.0
;
T
macro_avg_recall
=
0.0
;
for
(
size_t
i
=
0
;
i
<
cl
ass_di
m
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
cl
s_nu
m
;
++
i
)
{
T
tp_count
=
states_data
[
i
*
state_var_num
+
TP
];
T
fp_count
=
states_data
[
i
*
state_var_num
+
FP
];
T
fn_count
=
states_data
[
i
*
state_var_num
+
FN
];
...
...
@@ -139,8 +139,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
macro_avg_precision
+=
CalcPrecision
(
tp_count
,
fp_count
);
macro_avg_recall
+=
CalcRecall
(
tp_count
,
fn_count
);
}
macro_avg_precision
/=
cl
ass_di
m
;
macro_avg_recall
/=
cl
ass_di
m
;
macro_avg_precision
/=
cl
s_nu
m
;
macro_avg_recall
/=
cl
s_nu
m
;
T
macro_f1_score
=
CalcF1Score
(
macro_avg_precision
,
macro_avg_recall
);
T
micro_avg_precision
=
CalcPrecision
(
total_tp_count
,
total_fp_count
);
...
...
python/paddle/v2/framework/tests/test_precision_recall_op.py
浏览文件 @
970613fc
...
...
@@ -21,45 +21,44 @@ def calc_f1_score(precision, recall):
return
0.0
def
get_states
(
predictions
,
labels
,
weights
=
None
):
ins_num
=
predictions
.
shape
[
0
]
class_num
=
predictions
.
shape
[
1
]
def
get_states
(
idxs
,
labels
,
cls_num
,
weights
=
None
):
ins_num
=
idxs
.
shape
[
0
]
# TP FP TN FN
states
=
np
.
zeros
((
cl
as
s_num
,
4
)).
astype
(
'float32'
)
states
=
np
.
zeros
((
cls_num
,
4
)).
astype
(
'float32'
)
for
i
in
xrange
(
ins_num
):
w
=
weights
[
i
]
if
weights
is
not
None
else
1.0
max_idx
=
np
.
argmax
(
predictions
[
i
])
if
max_idx
==
labels
[
i
][
0
]:
states
[
max_idx
][
0
]
+=
w
for
j
in
xrange
(
class_num
):
idx
=
idxs
[
i
][
0
]
label
=
labels
[
i
][
0
]
if
idx
==
label
:
states
[
idx
][
0
]
+=
w
for
j
in
xrange
(
cls_num
):
states
[
j
][
2
]
+=
w
states
[
max_
idx
][
2
]
-=
w
states
[
idx
][
2
]
-=
w
else
:
states
[
label
s
[
i
][
0
]
][
3
]
+=
w
states
[
max_
idx
][
1
]
+=
w
for
j
in
xrange
(
cl
as
s_num
):
states
[
label
][
3
]
+=
w
states
[
idx
][
1
]
+=
w
for
j
in
xrange
(
cls_num
):
states
[
j
][
2
]
+=
w
states
[
label
s
[
i
][
0
]
][
2
]
-=
w
states
[
max_
idx
][
2
]
-=
w
states
[
label
][
2
]
-=
w
states
[
idx
][
2
]
-=
w
return
states
def
compute_metrics
(
states
):
class_num
=
states
.
shape
[
0
]
def
compute_metrics
(
states
,
cls_num
):
total_tp_count
=
0.0
total_fp_count
=
0.0
total_fn_count
=
0.0
macro_avg_precision
=
0.0
macro_avg_recall
=
0.0
for
i
in
xrange
(
cl
as
s_num
):
for
i
in
xrange
(
cls_num
):
total_tp_count
+=
states
[
i
][
0
]
total_fp_count
+=
states
[
i
][
1
]
total_fn_count
+=
states
[
i
][
3
]
macro_avg_precision
+=
calc_precision
(
states
[
i
][
0
],
states
[
i
][
1
])
macro_avg_recall
+=
calc_recall
(
states
[
i
][
0
],
states
[
i
][
3
])
metrics
=
[]
macro_avg_precision
/=
cl
as
s_num
macro_avg_recall
/=
cl
as
s_num
macro_avg_precision
/=
cls_num
macro_avg_recall
/=
cls_num
metrics
.
append
(
macro_avg_precision
)
metrics
.
append
(
macro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
macro_avg_precision
,
macro_avg_recall
))
...
...
@@ -75,15 +74,18 @@ class TestPrecisionRecallOp_0(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
class_num
=
10
predictions
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
class_num
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
class_num
),
ins_num
).
reshape
(
cls_num
=
10
max_probs
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
idxs
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
predictions
,
labels
)
metrics
=
compute_metrics
(
states
)
labels
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
idxs
,
labels
,
cls_num
)
metrics
=
compute_metrics
(
states
,
cls_num
)
self
.
attrs
=
{
'class_number'
:
cls_num
}
self
.
inputs
=
{
'
Predictions'
:
prediction
s
,
'Labels'
:
labels
}
self
.
inputs
=
{
'
MaxProbs'
:
max_probs
,
'Indices'
:
idx
s
,
'Labels'
:
labels
}
self
.
outputs
=
{
'BatchMetrics'
:
metrics
,
...
...
@@ -99,18 +101,22 @@ class TestPrecisionRecallOp_1(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
class_num
=
10
predictions
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
class_num
)).
astype
(
'float32'
)
cls_num
=
10
max_probs
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
idxs
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
weights
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
predictions
=
np
.
random
.
random
((
ins_num
,
class_num
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
class_num
),
ins_num
).
reshape
(
labels
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
predictions
,
labels
,
weights
)
metrics
=
compute_metrics
(
states
)
states
=
get_states
(
idxs
,
labels
,
cls_num
,
weights
)
metrics
=
compute_metrics
(
states
,
cls_num
)
self
.
attrs
=
{
'class_number'
:
cls_num
}
self
.
inputs
=
{
'Predictions'
:
predictions
,
'MaxProbs'
:
max_probs
,
'Indices'
:
idxs
,
'Labels'
:
labels
,
'Weights'
:
weights
}
...
...
@@ -129,22 +135,25 @@ class TestPrecisionRecallOp_2(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
class_num
=
10
predictions
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
class_num
)).
astype
(
'float32'
)
cls_num
=
10
max_probs
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
idxs
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
weights
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
predictions
=
np
.
random
.
random
((
ins_num
,
class_num
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
class_num
),
ins_num
).
reshape
(
labels
=
np
.
random
.
choice
(
xrange
(
cls_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
np
.
random
.
randint
(
0
,
30
,
(
cl
as
s_num
,
4
)).
astype
(
'float32'
)
states
=
np
.
random
.
randint
(
0
,
30
,
(
cls_num
,
4
)).
astype
(
'float32'
)
accum_states
=
get_states
(
predictions
,
labels
,
weights
)
batch_metrics
=
compute_metrics
(
accum_states
)
accum_states
=
get_states
(
idxs
,
labels
,
cls_num
,
weights
)
batch_metrics
=
compute_metrics
(
accum_states
,
cls_num
)
accum_states
+=
states
accum_metrics
=
compute_metrics
(
accum_states
)
accum_metrics
=
compute_metrics
(
accum_states
,
cls_num
)
self
.
attrs
=
{
'class_number'
:
cls_num
}
self
.
inputs
=
{
'Predictions'
:
predictions
,
'MaxProbs'
:
max_probs
,
'Indices'
:
idxs
,
'Labels'
:
labels
,
'Weights'
:
weights
,
'StatesInfo'
:
states
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录