Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
a5d3f512
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a5d3f512
编写于
6月 08, 2020
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix code style
上级
505d7e7c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
27 addition
and
26 deletion
+27
-26
core/trainer.py
core/trainer.py
+3
-3
core/utils/dataset_holder.py
core/utils/dataset_holder.py
+12
-12
core/utils/envs.py
core/utils/envs.py
+6
-7
core/utils/util.py
core/utils/util.py
+6
-4
未找到文件。
core/trainer.py
浏览文件 @
a5d3f512
...
...
@@ -62,7 +62,8 @@ class Trainer(object):
traceback
.
print_exc
()
print
(
'Catch Exception:%s'
%
str
(
err
))
sys
.
stdout
.
flush
()
self
.
_context
[
'is_exit'
]
=
self
.
handle_processor_exception
(
status
,
context
,
err
)
self
.
_context
[
'is_exit'
]
=
self
.
handle_processor_exception
(
status
,
context
,
err
)
def
other_status_processor
(
self
,
context
):
"""
...
...
@@ -72,7 +73,7 @@ class Trainer(object):
"""
print
(
'unknow context_status:%s, do nothing'
%
context
[
'status'
])
time
.
sleep
(
60
)
def
handle_processor_exception
(
self
,
status
,
context
,
exception
):
"""
when exception throwed from processor, will call this func to handle it
...
...
@@ -82,7 +83,6 @@ class Trainer(object):
print
(
'Exit app. catch exception in precoss status:%s, except:%s'
\
%
(
context
[
'status'
],
str
(
exception
)))
return
True
def
reload_train_context
(
self
):
"""
...
...
core/utils/dataset_holder.py
浏览文件 @
a5d3f512
...
...
@@ -66,7 +66,6 @@ class TimeSplitDatasetHolder(DatasetHolder):
"""
Dataset with time split dir. root_path/$DAY/$HOUR
"""
def
__init__
(
self
,
config
):
"""
init data root_path, time_split_interval, data_path_format
...
...
@@ -113,8 +112,8 @@ class TimeSplitDatasetHolder(DatasetHolder):
True/False
"""
is_ready
=
True
data_time
,
windows_mins
=
self
.
_format_data_time
(
daytime_str
,
time_window_mins
)
data_time
,
windows_mins
=
self
.
_format_data_time
(
daytime_str
,
time_window_mins
)
while
time_window_mins
>
0
:
file_path
=
self
.
_path_generator
.
generate_path
(
'donefile_path'
,
{
'time_format'
:
data_time
})
...
...
@@ -142,18 +141,19 @@ class TimeSplitDatasetHolder(DatasetHolder):
list, data_shard[node_idx]
"""
data_file_list
=
[]
data_time
,
windows_mins
=
self
.
_format_data_time
(
daytime_str
,
time_window_mins
)
data_time
,
windows_mins
=
self
.
_format_data_time
(
daytime_str
,
time_window_mins
)
while
time_window_mins
>
0
:
file_path
=
self
.
_path_generator
.
generate_path
(
'data_path'
,
{
'time_format'
:
data_time
})
sub_file_list
=
self
.
_data_file_handler
.
ls
(
file_path
)
for
sub_file
in
sub_file_list
:
sub_file_name
=
self
.
_data_file_handler
.
get_file_name
(
sub_file
)
if
not
sub_file_name
.
startswith
(
self
.
_config
[
'filename_prefix'
]):
if
not
sub_file_name
.
startswith
(
self
.
_config
[
'filename_prefix'
]):
continue
postfix
=
sub_file_name
.
split
(
self
.
_config
[
'filename_prefix'
])[
1
]
postfix
=
sub_file_name
.
split
(
self
.
_config
[
'filename_prefix'
])[
1
]
if
postfix
.
isdigit
():
if
int
(
postfix
)
%
node_num
==
node_idx
:
data_file_list
.
append
(
sub_file
)
...
...
@@ -167,8 +167,8 @@ class TimeSplitDatasetHolder(DatasetHolder):
def
_alloc_dataset
(
self
,
file_list
):
""" """
dataset
=
fluid
.
DatasetFactory
().
create_dataset
(
self
.
_config
[
'dataset_type'
])
dataset
=
fluid
.
DatasetFactory
().
create_dataset
(
self
.
_config
[
'dataset_type'
])
dataset
.
set_batch_size
(
self
.
_config
[
'batch_size'
])
dataset
.
set_thread
(
self
.
_config
[
'load_thread'
])
dataset
.
set_hdfs_config
(
self
.
_config
[
'fs_name'
],
...
...
@@ -207,8 +207,8 @@ class TimeSplitDatasetHolder(DatasetHolder):
params
[
'node_num'
],
params
[
'node_idx'
])
self
.
_datasets
[
begin_time
]
=
self
.
_alloc_dataset
(
file_list
)
self
.
_datasets
[
begin_time
].
preload_into_memory
(
self
.
_config
[
'preload_thread'
])
self
.
_datasets
[
begin_time
].
preload_into_memory
(
self
.
_config
[
'preload_thread'
])
return
True
return
False
...
...
core/utils/envs.py
浏览文件 @
a5d3f512
...
...
@@ -70,8 +70,8 @@ def set_global_envs(envs):
nests
=
copy
.
deepcopy
(
namespace_nests
)
nests
.
append
(
k
)
fatten_env_namespace
(
nests
,
v
)
elif
(
k
==
"dataset"
or
k
==
"phase"
or
k
==
"runner"
)
and
isinstance
(
v
,
list
):
elif
(
k
==
"dataset"
or
k
==
"phase"
or
k
==
"runner"
)
and
isinstance
(
v
,
list
):
for
i
in
v
:
if
i
.
get
(
"name"
)
is
None
:
raise
ValueError
(
"name must be in dataset list "
,
v
)
...
...
@@ -169,15 +169,14 @@ def pretty_print_envs(envs, header=None):
def
lazy_instance_by_package
(
package
,
class_name
):
try
:
model_package
=
__import__
(
package
,
globals
(),
locals
(),
package
.
split
(
"."
))
model_package
=
__import__
(
package
,
globals
(),
locals
(),
package
.
split
(
"."
))
instance
=
getattr
(
model_package
,
class_name
)
return
instance
except
Exception
,
err
:
traceback
.
print_exc
()
print
(
'Catch Exception:%s'
%
str
(
err
))
return
None
def
lazy_instance_by_fliename
(
abs
,
class_name
):
...
...
@@ -186,8 +185,8 @@ def lazy_instance_by_fliename(abs, class_name):
sys
.
path
.
append
(
dirname
)
package
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
abs
))[
0
]
model_package
=
__import__
(
package
,
globals
(),
locals
(),
package
.
split
(
"."
))
model_package
=
__import__
(
package
,
globals
(),
locals
(),
package
.
split
(
"."
))
instance
=
getattr
(
model_package
,
class_name
)
return
instance
except
Exception
,
err
:
...
...
core/utils/util.py
浏览文件 @
a5d3f512
...
...
@@ -101,6 +101,7 @@ def make_datetime(date_str, fmt=None):
return
datetime
.
datetime
.
strptime
(
date_str
,
'%Y%m%d%H%M'
)
return
datetime
.
datetime
.
strptime
(
date_str
,
fmt
)
def
wroker_numric_opt
(
fleet
,
value
,
env
,
opt
):
"""
numric count opt for workers
...
...
@@ -116,6 +117,7 @@ def wroker_numric_opt(fleet, value, env, opt):
fleet
.
_role_maker
.
all_reduce_worker
(
local_value
,
global_value
,
opt
)
return
global_value
[
0
]
def
worker_numric_sum
(
fleet
,
value
,
env
=
"mpi"
):
"""R
"""
...
...
@@ -139,6 +141,7 @@ def worker_numric_max(fleet, value, env="mpi"):
"""
return
wroker_numric_opt
(
fleet
,
value
,
env
,
"max"
)
def
print_log
(
log_str
,
params
):
"""R
"""
...
...
@@ -153,6 +156,7 @@ def print_log(log_str, params):
if
'stdout'
in
params
:
params
[
'stdout'
]
+=
log_str
+
'
\n
'
def
rank0_print
(
log_str
,
fleet
):
"""R
"""
...
...
@@ -171,7 +175,6 @@ class CostPrinter(object):
"""
For count cost time && print cost log
"""
def
__init__
(
self
,
callback
,
callback_params
):
"""R
"""
...
...
@@ -207,7 +210,6 @@ class PathGenerator(object):
"""
generate path with template & runtime variables
"""
def
__init__
(
self
,
config
):
"""R
"""
...
...
@@ -228,8 +230,8 @@ class PathGenerator(object):
"""
if
template_name
in
self
.
_templates
:
if
'time_format'
in
param
:
str
=
param
[
'time_format'
].
strftime
(
self
.
_templates
[
template_name
])
str
=
param
[
'time_format'
].
strftime
(
self
.
_templates
[
template_name
])
return
str
.
format
(
**
param
)
return
self
.
_templates
[
template_name
].
format
(
**
param
)
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录