Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c4ac7fab
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c4ac7fab
编写于
11月 07, 2017
作者:
D
Dong Zhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
'add f1 test'
上级
8d9b3341
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
12 addition
and
13 deletion
+12
-13
python/paddle/v2/framework/evaluator.py
python/paddle/v2/framework/evaluator.py
+8
-12
python/paddle/v2/framework/tests/test_fit_a_line.py
python/paddle/v2/framework/tests/test_fit_a_line.py
+4
-1
未找到文件。
python/paddle/v2/framework/evaluator.py
浏览文件 @
c4ac7fab
...
...
@@ -121,18 +121,14 @@ class Accuracy(Evaluator):
return
executor
.
run
(
eval_program
,
fetch_list
=
[
eval_out
])
#
This is demo for composing low level op to compute
metric
#
Demo for composing low level op to compute the F1
metric
class
F1
(
Evaluator
):
def
__init__
(
self
,
input
,
label
,
**
kwargs
):
super
(
F1
,
self
).
__init__
(
"F1"
,
**
kwargs
)
super
(
Accuracy
,
self
).
__init__
(
"accuracy"
,
**
kwargs
)
g_total
=
helper
.
create_global_variable
(
name
=
unique_name
(
"Total"
),
persistable
=
True
,
dtype
=
"int64"
,
shape
=
[
1
])
g_correct
=
helper
.
create_global_variable
(
name
=
unique_name
(
"Correct"
),
persistable
=
True
,
dtype
=
"int64"
,
shape
=
[
1
])
g_tp
=
helper
.
create_global_variable
(
name
=
unique_name
(
"Tp"
),
persistable
=
True
,
dtype
=
"int64"
,
shape
=
[
1
])
g_fp
=
helper
.
create_global_variable
(
name
=
unique_name
(
"Fp"
),
persistable
=
True
,
dtype
=
"int64"
,
shape
=
[
1
])
self
.
_states
[
"Tp"
]
=
g_tp
self
.
_states
[
"Fp"
]
=
g_fp
python/paddle/v2/framework/tests/test_fit_a_line.py
浏览文件 @
c4ac7fab
...
...
@@ -61,6 +61,7 @@ PASS_NUM = 100
for
pass_id
in
range
(
PASS_NUM
):
save_persistables
(
exe
,
"./fit_a_line.model/"
,
main_program
=
main_program
)
load_persistables
(
exe
,
"./fit_a_line.model/"
,
main_program
=
main_program
)
accuracy
.
reset
(
exe
)
for
data
in
train_reader
():
x_data
=
np
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"float32"
)
...
...
@@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM):
outs
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
tensor_x
,
'y'
:
tensor_y
},
fetch_list
=
[
avg_cost
])
fetch_list
=
[
avg_cost
,
accuracy
])
out
=
np
.
array
(
outs
[
0
])
pass_acc
=
accuracy
.
eval
(
exe
)
print
pass_acc
if
out
[
0
]
<
10.0
:
exit
(
0
)
# if avg cost less than 10.0, we think our code is good.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录