Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
79a2ce42
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79a2ce42
编写于
11月 07, 2017
作者:
D
Dong Zhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add small evaluation"
上级
7874399c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
22 addition
and
26 deletion
+22
-26
python/paddle/v2/framework/evaluator.py
python/paddle/v2/framework/evaluator.py
+22
-26
未找到文件。
python/paddle/v2/framework/evaluator.py
浏览文件 @
79a2ce42
from
paddle.v2.framework.framework
import
Program
,
g_program
,
unique_name
from
paddle.v2.framework.framework
import
Program
,
g_
main_
program
,
unique_name
from
paddle.v2.framework.layer_helper
import
LayerHelper
from
paddle.v2.framework.layer_helper
import
LayerHelper
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.core
as
core
...
@@ -14,17 +14,10 @@ class Evaluator(object):
...
@@ -14,17 +14,10 @@ class Evaluator(object):
def
__init__
(
self
,
name
,
**
kwargs
):
def
__init__
(
self
,
name
,
**
kwargs
):
self
.
_states
=
{}
self
.
_states
=
{}
self
.
_helper
=
LayerHelper
(
layer_type
=
name
,
**
kwargs
)
if
kwargs
.
has_key
(
"program"
):
# if kwargs.has_key("program"):
self
.
_program
=
kwargs
.
get
(
"program"
)
# self._program = kwargs.get("program")
else
:
# else:
self
.
_program
=
g_main_program
# self._program = g_program
# def _update(self):
# """
# Updates the internal states througth operator
# """
# raise NotImplementedError()
def
reset
(
self
,
executor
,
program
=
None
):
def
reset
(
self
,
executor
,
program
=
None
):
"""
"""
...
@@ -34,20 +27,21 @@ class Evaluator(object):
...
@@ -34,20 +27,21 @@ class Evaluator(object):
reset_program
=
Program
()
reset_program
=
Program
()
else
:
else
:
reset_program
=
program
reset_program
=
program
block
=
reset_program
.
global_block
()
for
k
,
var
in
self
.
_states
.
iteritems
():
for
k
,
var
in
self
.
_states
.
iteritems
():
zeros
=
helper
.
create_tmp_variable
(
dtype
=
var
.
data_type
)
zeros
=
block
.
create_var
(
dtype
=
var
.
data_type
)
self
.
_helper
.
append_op
(
block
.
append_op
(
type
=
"fill_constant"
,
type
=
"fill_constant"
,
outputs
=
{
"Out"
:
[
zeros
]},
outputs
=
{
"Out"
:
[
zeros
]},
attrs
=
{
attrs
=
{
"shape"
:
var
.
shape
,
"shape"
:
var
.
shape
,
"value"
:
0
,
"value"
:
0
,
})
})
self
.
_helper
.
append_op
(
block
.
append_op
(
type
=
"scale"
,
inputs
=
{
"X"
:
zeros
},
outputs
=
{
"Out"
:
var
})
type
=
"scale"
,
inputs
=
{
"X"
:
zeros
},
outputs
=
{
"Out"
:
var
})
executor
.
run
(
reset_program
)
executor
.
run
(
reset_program
)
def
eval
(
self
):
def
eval
(
self
,
executor
,
program
=
None
):
"""
"""
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
"""
"""
...
@@ -61,7 +55,8 @@ class Accuracy(Evaluator):
...
@@ -61,7 +55,8 @@ class Accuracy(Evaluator):
def
__init__
(
self
,
input
,
label
,
k
=
1
,
**
kwargs
):
def
__init__
(
self
,
input
,
label
,
k
=
1
,
**
kwargs
):
super
(
Accuracy
,
self
).
__init__
(
"accuracy"
,
**
kwargs
)
super
(
Accuracy
,
self
).
__init__
(
"accuracy"
,
**
kwargs
)
g_total
=
helper
.
create_global_variable
(
block
=
self
.
_program
.
global_block
()
g_total
=
block
.
create_var
(
name
=
unique_name
(
"Total"
),
name
=
unique_name
(
"Total"
),
persistable
=
True
,
persistable
=
True
,
dtype
=
"int64"
,
dtype
=
"int64"
,
...
@@ -74,17 +69,17 @@ class Accuracy(Evaluator):
...
@@ -74,17 +69,17 @@ class Accuracy(Evaluator):
self
.
_states
[
"Total"
]
=
g_total
self
.
_states
[
"Total"
]
=
g_total
self
.
_states
[
"Correct"
]
=
g_correct
self
.
_states
[
"Correct"
]
=
g_correct
topk_out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
data_type
)
topk_out
=
block
.
create_var
(
dtype
=
input
.
data_type
)
topk_indices
=
helper
.
create_tmp_variable
(
dtype
=
"int64"
)
topk_indices
=
block
.
create_var
(
dtype
=
"int64"
)
helper
.
append_op
(
block
.
append_op
(
type
=
"top_k"
,
type
=
"top_k"
,
inputs
=
{
"X"
:
[
input
]},
inputs
=
{
"X"
:
[
input
]},
outputs
=
{
"Out"
:
[
topk_out
],
outputs
=
{
"Out"
:
[
topk_out
],
"Indices"
:
[
topk_indices
]},
"Indices"
:
[
topk_indices
]},
attrs
=
{
"k"
:
k
})
attrs
=
{
"k"
:
k
})
acc_out_dtype
=
kwargs
.
get
(
"out_dtype"
,
"float32"
)
acc_out_dtype
=
kwargs
.
get
(
"out_dtype"
,
"float32"
)
acc_out
=
helper
.
create_tmp_variable
(
dtype
=
acc_out_dtype
)
acc_out
=
block
.
create_var
(
dtype
=
acc_out_dtype
)
helper
.
append_op
(
block
.
append_op
(
type
=
"accuracy"
,
type
=
"accuracy"
,
inputs
=
{
inputs
=
{
"Out"
:
[
topk_out
],
"Out"
:
[
topk_out
],
...
@@ -97,11 +92,11 @@ class Accuracy(Evaluator):
...
@@ -97,11 +92,11 @@ class Accuracy(Evaluator):
"Total"
:
[
total
],
"Total"
:
[
total
],
})
})
helper
.
append_op
(
block
.
append_op
(
type
=
"sum"
,
type
=
"sum"
,
inputs
=
{
"X"
:
[
g_total
,
total
]},
inputs
=
{
"X"
:
[
g_total
,
total
]},
outputs
=
{
"Out"
:
[
g_total
]})
outputs
=
{
"Out"
:
[
g_total
]})
helper
.
append_op
(
block
.
append_op
(
type
=
"sum"
,
type
=
"sum"
,
inputs
=
{
"X"
:
[
g_correct
,
correct
]},
inputs
=
{
"X"
:
[
g_correct
,
correct
]},
outputs
=
{
"Out"
:
[
g_total
]})
outputs
=
{
"Out"
:
[
g_total
]})
...
@@ -112,8 +107,9 @@ class Accuracy(Evaluator):
...
@@ -112,8 +107,9 @@ class Accuracy(Evaluator):
eval_program
=
Program
()
eval_program
=
Program
()
else
:
else
:
eval_program
=
program
eval_program
=
program
eval_out
=
helper
.
create_tmp_variable
(
dtype
=
self
.
_helper
.
input_dtype
())
block
=
eval_program
.
global_block
()
self
.
_helper
.
append_op
(
eval_out
=
block
.
create_var
(
dtype
=
self
.
_helper
.
input_dtype
())
block
.
append_op
(
type
=
"elementwise_div"
,
type
=
"elementwise_div"
,
inputs
=
{
"X"
:
self
.
_states
[
"Total"
],
inputs
=
{
"X"
:
self
.
_states
[
"Total"
],
"Y"
:
self
.
_states
[
"Correct"
]},
"Y"
:
self
.
_states
[
"Correct"
]},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录