Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
88600d7d
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
88600d7d
编写于
6月 17, 2020
作者:
C
Chengmo
提交者:
GitHub
6月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix lod tensor return (#101)
Co-authored-by:
N
tangwei
<
tangwei12@baidu.com
>
上级
5a521800
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
49 addition
and
18 deletion
+49
-18
core/trainers/framework/runner.py
core/trainers/framework/runner.py
+48
-12
core/trainers/general_trainer.py
core/trainers/general_trainer.py
+1
-6
未找到文件。
core/trainers/framework/runner.py
浏览文件 @
88600d7d
...
...
@@ -16,10 +16,9 @@ from __future__ import print_function
import
os
import
time
import
warnings
import
datetime
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddlerec.core.utils
import
envs
__all__
=
[
...
...
@@ -27,6 +26,42 @@ __all__ = [
]
def
as_numpy
(
tensor
):
"""
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
new_scope = fluid.Scope()
with fluid.scope_guard(new_scope):
fluid.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), fluid.CPUPlace())
tensor = new_scope.find_var("data").get_tensor()
fluid.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor())
Args:
tensor(Variable): a instance of Tensor
Returns:
numpy.ndarray
"""
if
isinstance
(
tensor
,
fluid
.
core
.
LoDTensorArray
):
return
[
as_numpy
(
t
)
for
t
in
tensor
]
if
isinstance
(
tensor
,
list
):
return
[
as_numpy
(
t
)
for
t
in
tensor
]
assert
isinstance
(
tensor
,
fluid
.
core
.
LoDTensor
)
lod
=
tensor
.
lod
()
# (todo) need print lod or return it for user
if
tensor
.
_is_initialized
():
return
np
.
array
(
tensor
)
else
:
return
None
class
RunnerBase
(
object
):
"""R
"""
...
...
@@ -92,9 +127,6 @@ class RunnerBase(object):
model_class
=
context
[
"model"
][
model_dict
[
"name"
]][
"model"
]
program
=
self
.
_get_dataloader_program
(
model_dict
,
context
)
reader_name
=
model_dict
[
"dataset_name"
]
fetch_vars
=
[]
fetch_alias
=
[]
fetch_period
=
int
(
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
".print_interval"
,
20
))
...
...
@@ -103,9 +135,6 @@ class RunnerBase(object):
else
:
metrics
=
model_class
.
get_metrics
()
if
metrics
:
fetch_vars
=
metrics
.
values
()
fetch_alias
=
metrics
.
keys
()
metrics_varnames
=
[]
metrics_format
=
[]
metrics_format
.
append
(
"{}: {{}}"
.
format
(
"batch"
))
...
...
@@ -121,9 +150,16 @@ class RunnerBase(object):
with
fluid
.
scope_guard
(
scope
):
try
:
while
True
:
metrics_rets
=
context
[
"exe"
].
run
(
program
=
program
,
fetch_list
=
metrics_varnames
)
metrics_tensors
=
context
[
"exe"
].
run
(
program
=
program
,
fetch_list
=
metrics_varnames
,
return_numpy
=
False
)
metrics
=
[
batch_id
]
metrics_rets
=
[
as_numpy
(
metrics_tensor
)
for
metrics_tensor
in
metrics_tensors
]
metrics
.
extend
(
metrics_rets
)
if
batch_id
%
fetch_period
==
0
and
batch_id
!=
0
:
...
...
@@ -248,7 +284,7 @@ class RunnerBase(object):
fetch_varnames
=
envs
.
get_global_env
(
name
+
"save_inference_fetch_varnames"
,
[])
if
feed_varnames
is
None
or
fetch_varnames
is
None
or
feed_varnames
==
""
or
fetch_varnames
==
""
or
\
len
(
feed_varnames
)
==
0
or
len
(
fetch_varnames
)
==
0
:
len
(
feed_varnames
)
==
0
or
len
(
fetch_varnames
)
==
0
:
return
fetch_vars
=
[
fluid
.
default_main_program
().
global_block
().
vars
[
varname
]
...
...
core/trainers/general_trainer.py
浏览文件 @
88600d7d
...
...
@@ -19,12 +19,7 @@ from __future__ import print_function
import
os
from
paddlerec.core.utils
import
envs
from
paddlerec.core.trainer
import
Trainer
,
EngineMode
,
FleetMode
,
Device
from
paddlerec.core.trainers.framework.dataset
import
*
from
paddlerec.core.trainers.framework.runner
import
*
from
paddlerec.core.trainers.framework.instance
import
*
from
paddlerec.core.trainers.framework.network
import
*
from
paddlerec.core.trainers.framework.startup
import
*
from
paddlerec.core.trainer
import
Trainer
,
EngineMode
,
FleetMode
class
GeneralTrainer
(
Trainer
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录