Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
65dbbd57
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65dbbd57
编写于
10月 26, 2017
作者:
Y
yangyaming
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add and pass unittests.
上级
06c7c8c8
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
188 addition
and
11 deletion
+188
-11
paddle/operators/precision_recall_op.cc
paddle/operators/precision_recall_op.cc
+16
-5
paddle/operators/precision_recall_op.h
paddle/operators/precision_recall_op.h
+8
-6
python/paddle/v2/framework/tests/test_precision_recall_op.py
python/paddle/v2/framework/tests/test_precision_recall_op.py
+164
-0
未找到文件。
paddle/operators/precision_recall_op.cc
浏览文件 @
65dbbd57
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/operators/precision_recall_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -37,13 +39,15 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
...
@@ -37,13 +39,15 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
if
(
ctx
->
HasInput
(
"Weights"
))
{
if
(
ctx
->
HasInput
(
"Weights"
))
{
auto
weights_dims
=
ctx
->
GetInputDim
(
"Weights"
);
auto
weights_dims
=
ctx
->
GetInputDim
(
"Weights"
);
PADDLE_ENFORCE_EQ
(
weights_dims
,
{
predictions_dims
[
0
],
1
},
PADDLE_ENFORCE_EQ
(
weights_dims
,
framework
::
make_ddim
({
predictions_dims
[
0
],
1
}),
"The shape of Input(Weights) should be "
"The shape of Input(Weights) should be "
"[batch_size, 1]."
);
"[batch_size, 1]."
);
}
}
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
PADDLE_ENFORCE_EQ
(
states_dims
,
{
predictions_dims
[
1
],
4
},
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
predictions_dims
[
1
],
4
}),
"The shape of Input(StatesInfo) should be "
"The shape of Input(StatesInfo) should be "
"[class_number, 4]."
);
"[class_number, 4]."
);
}
}
...
@@ -71,6 +75,12 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
...
@@ -71,6 +75,12 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
// [ TP, FP, TN, FN ]
// [ TP, FP, TN, FN ]
ctx
->
SetOutputDim
(
"AccumStatesInfo"
,
{
predictions_dims
[
1
],
4
});
ctx
->
SetOutputDim
(
"AccumStatesInfo"
,
{
predictions_dims
[
1
],
4
});
}
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Predictions"
)
->
type
());
}
};
};
class
PrecisionRecallOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
PrecisionRecallOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -98,6 +108,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -98,6 +108,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
"provided, current state will be accumulated to this state and "
"provided, current state will be accumulated to this state and "
"the accumulation state will be as the output state."
)
"the accumulation state will be as the output state."
)
.
AsDispensable
();
.
AsDispensable
();
AddOutput
(
"BatchMetrics"
,
""
);
AddOutput
(
"AccumMetrics"
,
""
);
AddOutput
(
"AccumStatesInfo"
,
""
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
)DOC"
);
)DOC"
);
...
@@ -113,6 +126,4 @@ REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
...
@@ -113,6 +126,4 @@ REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
precision_recall
,
precision_recall
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
int
>
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
,
ops
::
PrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
int64_t
>
,
paddle/operators/precision_recall_op.h
浏览文件 @
65dbbd57
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -37,7 +39,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
...
@@ -37,7 +39,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
auto
*
out2
=
ctx
.
Output
<
Tensor
>
(
"AccumStatesInfo"
);
auto
*
out2
=
ctx
.
Output
<
Tensor
>
(
"AccumStatesInfo"
);
const
T
*
predictions_data
=
in0
->
data
<
T
>
();
const
T
*
predictions_data
=
in0
->
data
<
T
>
();
const
T
*
labels_data
=
in1
->
data
<
T
>
();
const
int
*
labels_data
=
in1
->
data
<
int
>
();
const
T
*
weights_data
=
in2
?
in2
->
data
<
T
>
()
:
nullptr
;
const
T
*
weights_data
=
in2
?
in2
->
data
<
T
>
()
:
nullptr
;
const
T
*
states_data
=
in3
?
in3
->
data
<
T
>
()
:
nullptr
;
const
T
*
states_data
=
in3
?
in3
->
data
<
T
>
()
:
nullptr
;
T
*
batch_metrics_data
=
out0
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batch_metrics_data
=
out0
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -45,7 +47,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
...
@@ -45,7 +47,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
out2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
accum_states
=
EigenMatrix
<
T
>::
From
(
*
out2
);
auto
accum_states
=
EigenMatrix
<
T
>::
From
(
*
out2
);
accum_states
.
setZero
();
accum_states
.
setZero
();
T
*
accum_states_data
=
out2
->
data
<
T
>
(
ctx
.
GetPlace
()
);
T
*
accum_states_data
=
out2
->
data
<
T
>
();
size_t
sample_num
=
in0
->
dims
()[
0
];
size_t
sample_num
=
in0
->
dims
()[
0
];
size_t
class_dim
=
in0
->
dims
()[
1
];
size_t
class_dim
=
in0
->
dims
()[
1
];
...
@@ -76,7 +78,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
...
@@ -76,7 +78,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
accum_states_data
[
j
*
state_var_num
+
TN
]
+=
w
;
accum_states_data
[
j
*
state_var_num
+
TN
]
+=
w
;
}
}
accum_states_data
[
max_idx
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
max_idx
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
labels_data
[
j
]
*
state_var_num
+
TN
]
-=
w
;
accum_states_data
[
labels_data
[
i
]
*
state_var_num
+
TN
]
-=
w
;
}
}
}
}
...
@@ -108,7 +110,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
...
@@ -108,7 +110,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
if
(
tp_count
>
0.0
||
fn_count
>
0.0
)
{
if
(
tp_count
>
0.0
||
fn_count
>
0.0
)
{
return
tp_count
/
(
tp_count
+
fn_count
);
return
tp_count
/
(
tp_count
+
fn_count
);
}
}
return
1.0
return
1.0
;
}
}
static
inline
T
CalcF1Score
(
T
precision
,
T
recall
)
{
static
inline
T
CalcF1Score
(
T
precision
,
T
recall
)
{
...
@@ -120,7 +122,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
...
@@ -120,7 +122,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
protected:
protected:
void
ComputeMetrics
(
const
T
*
states_data
,
T
*
metrics_data
,
void
ComputeMetrics
(
const
T
*
states_data
,
T
*
metrics_data
,
size_t
state_var_num
,
size_t
class_dim
)
{
size_t
state_var_num
,
size_t
class_dim
)
const
{
T
total_tp_count
=
0
;
T
total_tp_count
=
0
;
T
total_fp_count
=
0
;
T
total_fp_count
=
0
;
T
total_fn_count
=
0
;
T
total_fn_count
=
0
;
...
@@ -143,7 +145,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
...
@@ -143,7 +145,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
T
micro_avg_precision
=
CalcPrecision
(
total_tp_count
,
total_fp_count
);
T
micro_avg_precision
=
CalcPrecision
(
total_tp_count
,
total_fp_count
);
T
micro_avg_recall
=
CalcRecall
(
total_tp_count
,
total_fn_count
);
T
micro_avg_recall
=
CalcRecall
(
total_tp_count
,
total_fn_count
);
T
micro_f1_score
=
Calc
Recall
(
micro_avg_precision
,
micro_avg_recall
);
T
micro_f1_score
=
Calc
F1Score
(
micro_avg_precision
,
micro_avg_recall
);
// fill metrics data
// fill metrics data
metrics_data
[
0
]
=
macro_avg_precision
;
metrics_data
[
0
]
=
macro_avg_precision
;
...
...
python/paddle/v2/framework/tests/test_precision_recall_op.py
0 → 100644
浏览文件 @
65dbbd57
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
calc_precision
(
tp_count
,
fp_count
):
if
tp_count
>
0.0
or
fp_count
>
0.0
:
return
tp_count
/
(
tp_count
+
fp_count
)
return
1.0
def
calc_recall
(
tp_count
,
fn_count
):
if
tp_count
>
0.0
or
fn_count
>
0.0
:
return
tp_count
/
(
tp_count
+
fn_count
)
return
1.0
def
calc_f1_score
(
precision
,
recall
):
if
precision
>
0.0
or
recall
>
0.0
:
return
2
*
precision
*
recall
/
(
precision
+
recall
)
return
0.0
def
get_states
(
predictions
,
labels
,
weights
=
None
):
ins_num
=
predictions
.
shape
[
0
]
class_num
=
predictions
.
shape
[
1
]
# TP FP TN FN
states
=
np
.
zeros
((
class_num
,
4
)).
astype
(
'float32'
)
for
i
in
xrange
(
ins_num
):
w
=
weights
[
i
]
if
weights
is
not
None
else
1.0
max_idx
=
np
.
argmax
(
predictions
[
i
])
if
max_idx
==
labels
[
i
][
0
]:
states
[
max_idx
][
0
]
+=
w
for
j
in
xrange
(
class_num
):
states
[
j
][
2
]
+=
w
states
[
max_idx
][
2
]
-=
w
else
:
states
[
labels
[
i
][
0
]][
3
]
+=
w
states
[
max_idx
][
1
]
+=
w
for
j
in
xrange
(
class_num
):
states
[
j
][
2
]
+=
w
states
[
labels
[
i
][
0
]][
2
]
-=
w
states
[
max_idx
][
2
]
-=
w
return
states
def
compute_metrics
(
states
):
class_num
=
states
.
shape
[
0
]
total_tp_count
=
0.0
total_fp_count
=
0.0
total_fn_count
=
0.0
macro_avg_precision
=
0.0
macro_avg_recall
=
0.0
for
i
in
xrange
(
class_num
):
total_tp_count
+=
states
[
i
][
0
]
total_fp_count
+=
states
[
i
][
1
]
total_fn_count
+=
states
[
i
][
3
]
macro_avg_precision
+=
calc_precision
(
states
[
i
][
0
],
states
[
i
][
1
])
macro_avg_recall
+=
calc_recall
(
states
[
i
][
0
],
states
[
i
][
3
])
metrics
=
[]
macro_avg_precision
/=
class_num
macro_avg_recall
/=
class_num
metrics
.
append
(
macro_avg_precision
)
metrics
.
append
(
macro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
macro_avg_precision
,
macro_avg_recall
))
micro_avg_precision
=
calc_precision
(
total_tp_count
,
total_fp_count
)
metrics
.
append
(
micro_avg_precision
)
micro_avg_recall
=
calc_recall
(
total_tp_count
,
total_fn_count
)
metrics
.
append
(
micro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
micro_avg_precision
,
micro_avg_recall
))
return
np
.
array
(
metrics
).
astype
(
'float32'
)
class
TestPrecisionRecallOp_0
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
class_num
=
10
predictions
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
class_num
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
class_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
predictions
,
labels
)
metrics
=
compute_metrics
(
states
)
self
.
inputs
=
{
'Predictions'
:
predictions
,
'Labels'
:
labels
}
self
.
outputs
=
{
'BatchMetrics'
:
metrics
,
'AccumMetrics'
:
metrics
,
'AccumStatesInfo'
:
states
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestPrecisionRecallOp_1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
class_num
=
10
predictions
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
class_num
)).
astype
(
'float32'
)
weights
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
predictions
=
np
.
random
.
random
((
ins_num
,
class_num
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
class_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
get_states
(
predictions
,
labels
,
weights
)
metrics
=
compute_metrics
(
states
)
self
.
inputs
=
{
'Predictions'
:
predictions
,
'Labels'
:
labels
,
'Weights'
:
weights
}
self
.
outputs
=
{
'BatchMetrics'
:
metrics
,
'AccumMetrics'
:
metrics
,
'AccumStatesInfo'
:
states
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestPrecisionRecallOp_2
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"precision_recall"
ins_num
=
64
class_num
=
10
predictions
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
class_num
)).
astype
(
'float32'
)
weights
=
np
.
random
.
uniform
(
0
,
1.0
,
(
ins_num
,
1
)).
astype
(
'float32'
)
predictions
=
np
.
random
.
random
((
ins_num
,
class_num
)).
astype
(
'float32'
)
labels
=
np
.
random
.
choice
(
xrange
(
class_num
),
ins_num
).
reshape
(
(
ins_num
,
1
)).
astype
(
'int32'
)
states
=
np
.
random
.
randint
(
0
,
30
,
(
class_num
,
4
)).
astype
(
'float32'
)
accum_states
=
get_states
(
predictions
,
labels
,
weights
)
batch_metrics
=
compute_metrics
(
accum_states
)
accum_states
+=
states
accum_metrics
=
compute_metrics
(
accum_states
)
self
.
inputs
=
{
'Predictions'
:
predictions
,
'Labels'
:
labels
,
'Weights'
:
weights
,
'StatesInfo'
:
states
}
self
.
outputs
=
{
'BatchMetrics'
:
batch_metrics
,
'AccumMetrics'
:
accum_metrics
,
'AccumStatesInfo'
:
accum_states
}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录