Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
c62b8d6d
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c62b8d6d
编写于
6月 08, 2020
作者:
X
xionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revert it
上级
cdc0f7c9
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
122 addition
and
186 deletion
+122
-186
core/metrics/auc_metrics.py
core/metrics/auc_metrics.py
+1
-1
core/modules/modul/build.py
core/modules/modul/build.py
+7
-2
core/trainer.py
core/trainer.py
+23
-4
core/utils/dataset_holder.py
core/utils/dataset_holder.py
+9
-3
core/utils/envs.py
core/utils/envs.py
+23
-13
core/utils/util.py
core/utils/util.py
+59
-163
未找到文件。
core/metrics/auc_metrics.py
浏览文件 @
c62b8d6d
...
...
@@ -66,7 +66,7 @@ class AUCMetric(Metric):
old_metric_shape
=
np
.
array
(
metric
.
shape
)
metric
=
metric
.
reshape
(
-
1
)
global_metric
=
np
.
copy
(
metric
)
*
0
self
.
fleet
.
_role_maker
.
_node_type_comm
.
Allreduce
(
metric
,
global_metric
)
self
.
fleet
.
_role_maker
.
all_reduce_worker
(
metric
,
global_metric
)
global_metric
=
global_metric
.
reshape
(
old_metric_shape
)
return
global_metric
[
0
]
...
...
core/modules/modul/build.py
浏览文件 @
c62b8d6d
...
...
@@ -33,8 +33,13 @@ def create(config):
model
=
None
if
config
[
'mode'
]
==
'fluid'
:
if
config
[
'layer_file'
].
endswith
(
".py"
):
model_class
=
envs
.
lazy_instance_by_fliename
(
config
[
'layer_file'
],
"Model"
)
model
=
model_class
(
config
)
else
:
model
=
YamlModel
(
config
)
model
.
train_net
()
model
.
train
()
return
model
...
...
core/trainer.py
浏览文件 @
c62b8d6d
...
...
@@ -17,6 +17,7 @@ import os
import
time
import
sys
import
yaml
import
traceback
from
paddle
import
fluid
...
...
@@ -51,10 +52,18 @@ class Trainer(object):
Return:
None : run a processor for this status
"""
if
context
[
'status'
]
in
self
.
_status_processor
:
status
=
context
[
'status'
]
try
:
if
status
in
self
.
_status_processor
:
self
.
_status_processor
[
context
[
'status'
]](
context
)
else
:
self
.
other_status_processor
(
context
)
except
Exception
,
err
:
traceback
.
print_exc
()
print
(
'Catch Exception:%s'
%
str
(
err
))
sys
.
stdout
.
flush
()
self
.
_context
[
'is_exit'
]
=
self
.
handle_processor_exception
(
status
,
context
,
err
)
def
other_status_processor
(
self
,
context
):
"""
...
...
@@ -65,6 +74,16 @@ class Trainer(object):
print
(
'unknow context_status:%s, do nothing'
%
context
[
'status'
])
time
.
sleep
(
60
)
def
handle_processor_exception
(
self
,
status
,
context
,
exception
):
"""
when exception throwed from processor, will call this func to handle it
Return:
bool exit_app or not
"""
print
(
'Exit app. catch exception in precoss status:%s, except:%s'
\
%
(
context
[
'status'
],
str
(
exception
)))
return
True
def
reload_train_context
(
self
):
"""
context maybe update timely, reload for update
...
...
core/utils/dataset_holder.py
浏览文件 @
c62b8d6d
...
...
@@ -71,7 +71,7 @@ class TimeSplitDatasetHolder(DatasetHolder):
"""
init data root_path, time_split_interval, data_path_format
"""
Dataset
.
__init__
(
self
,
config
)
Dataset
Holder
.
__init__
(
self
,
config
)
if
'data_donefile'
not
in
config
or
config
[
'data_donefile'
]
is
None
:
config
[
'data_donefile'
]
=
config
[
'data_path'
]
+
"/to.hadoop.done"
self
.
_path_generator
=
util
.
PathGenerator
({
...
...
@@ -153,6 +153,12 @@ class TimeSplitDatasetHolder(DatasetHolder):
if
not
sub_file_name
.
startswith
(
self
.
_config
[
'filename_prefix'
]):
continue
postfix
=
sub_file_name
.
split
(
self
.
_config
[
'filename_prefix'
])[
1
]
if
postfix
.
isdigit
():
if
int
(
postfix
)
%
node_num
==
node_idx
:
data_file_list
.
append
(
sub_file
)
else
:
if
hash
(
sub_file_name
)
%
node_num
==
node_idx
:
data_file_list
.
append
(
sub_file
)
time_window_mins
=
time_window_mins
-
self
.
_split_interval
...
...
core/utils/envs.py
浏览文件 @
c62b8d6d
...
...
@@ -18,6 +18,7 @@ import copy
import
os
import
socket
import
sys
import
traceback
global_envs
=
{}
...
...
@@ -167,14 +168,19 @@ def pretty_print_envs(envs, header=None):
def
lazy_instance_by_package
(
package
,
class_name
):
models
=
get_global_env
(
"train.model.models"
)
try
:
model_package
=
__import__
(
package
,
globals
(),
locals
(),
package
.
split
(
"."
))
instance
=
getattr
(
model_package
,
class_name
)
return
instance
except
Exception
,
err
:
traceback
.
print_exc
()
print
(
'Catch Exception:%s'
%
str
(
err
))
return
None
def
lazy_instance_by_fliename
(
abs
,
class_name
):
try
:
dirname
=
os
.
path
.
dirname
(
abs
)
sys
.
path
.
append
(
dirname
)
package
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
abs
))[
0
]
...
...
@@ -183,6 +189,10 @@ def lazy_instance_by_fliename(abs, class_name):
globals
(),
locals
(),
package
.
split
(
"."
))
instance
=
getattr
(
model_package
,
class_name
)
return
instance
except
Exception
,
err
:
traceback
.
print_exc
()
print
(
'Catch Exception:%s'
%
str
(
err
))
return
None
def
get_platform
():
...
...
core/utils/util.py
浏览文件 @
c62b8d6d
...
...
@@ -14,8 +14,9 @@
import
datetime
import
os
import
sys
import
time
import
numpy
as
np
from
paddle
import
fluid
from
paddlerec.core.utils
import
fs
as
fs
...
...
@@ -101,10 +102,65 @@ def make_datetime(date_str, fmt=None):
return
datetime
.
datetime
.
strptime
(
date_str
,
fmt
)
def
rank0_print
(
log_str
):
def
wroker_numric_opt
(
fleet
,
value
,
env
,
opt
):
"""
numric count opt for workers
Args:
value: value for count
env: mpi/gloo
opt: count operator, SUM/MAX/MIN/AVG
Return:
count result
"""
local_value
=
np
.
array
([
value
])
global_value
=
np
.
copy
(
local_value
)
*
0
fleet
.
_role_maker
.
all_reduce_worker
(
local_value
,
global_value
,
opt
)
return
global_value
[
0
]
def
worker_numric_sum
(
fleet
,
value
,
env
=
"mpi"
):
"""R
"""
return
wroker_numric_opt
(
fleet
,
value
,
env
,
"sum"
)
def
worker_numric_avg
(
fleet
,
value
,
env
=
"mpi"
):
"""R
"""
return
worker_numric_sum
(
fleet
,
value
,
env
)
/
fleet
.
worker_num
()
def
worker_numric_min
(
fleet
,
value
,
env
=
"mpi"
):
"""R
"""
print_log
(
log_str
,
{
'master'
:
True
})
return
wroker_numric_opt
(
fleet
,
value
,
env
,
"min"
)
def
worker_numric_max
(
fleet
,
value
,
env
=
"mpi"
):
"""R
"""
return
wroker_numric_opt
(
fleet
,
value
,
env
,
"max"
)
def
print_log
(
log_str
,
params
):
"""R
"""
time_str
=
time
.
strftime
(
"[%Y-%m-%d %H:%M:%S]"
,
time
.
localtime
())
log_str
=
time_str
+
" "
+
log_str
if
'master'
in
params
and
params
[
'master'
]:
if
'index'
in
params
and
params
[
'index'
]
==
0
:
print
(
log_str
)
else
:
print
(
log_str
)
sys
.
stdout
.
flush
()
if
'stdout'
in
params
:
params
[
'stdout'
]
+=
log_str
+
'
\n
'
def
rank0_print
(
log_str
,
fleet
):
"""R
"""
print_log
(
log_str
,
{
'master'
:
True
,
'index'
:
fleet
.
worker_index
()})
def
print_cost
(
cost
,
params
):
...
...
@@ -182,163 +238,3 @@ class PathGenerator(object):
return
self
.
_templates
[
template_name
].
format
(
**
param
)
else
:
return
""
class
TimeTrainPass
(
object
):
"""
timely pass
define pass time_interval && start_time && end_time
"""
def
__init__
(
self
,
global_config
):
"""R
"""
self
.
_config
=
global_config
[
'epoch'
]
if
'+'
in
self
.
_config
[
'days'
]:
day_str
=
self
.
_config
[
'days'
].
replace
(
' '
,
''
)
day_fields
=
day_str
.
split
(
'+'
)
self
.
_begin_day
=
make_datetime
(
day_fields
[
0
].
strip
())
if
len
(
day_fields
)
==
1
or
len
(
day_fields
[
1
])
==
0
:
# 100 years, meaning to continuous running
self
.
_end_day
=
self
.
_begin_day
+
datetime
.
timedelta
(
days
=
36500
)
else
:
# example: 2020212+10
run_day
=
int
(
day_fields
[
1
].
strip
())
self
.
_end_day
=
self
.
_begin_day
+
datetime
.
timedelta
(
days
=
run_day
)
else
:
# example: {20191001..20191031}
days
=
os
.
popen
(
"echo -n "
+
self
.
_config
[
'days'
]).
read
().
split
(
" "
)
self
.
_begin_day
=
make_datetime
(
days
[
0
])
self
.
_end_day
=
make_datetime
(
days
[
len
(
days
)
-
1
])
self
.
_checkpoint_interval
=
self
.
_config
[
'checkpoint_interval'
]
self
.
_dump_inference_interval
=
self
.
_config
[
'dump_inference_interval'
]
self
.
_interval_per_pass
=
self
.
_config
[
'train_time_interval'
]
# train N min data per pass
self
.
_pass_id
=
0
self
.
_inference_pass_id
=
0
self
.
_pass_donefile_handler
=
None
if
'pass_donefile_name'
in
self
.
_config
:
self
.
_train_pass_donefile
=
global_config
[
'output_path'
]
+
'/'
+
self
.
_config
[
'pass_donefile_name'
]
if
fs
.
is_afs_path
(
self
.
_train_pass_donefile
):
self
.
_pass_donefile_handler
=
fs
.
FileHandler
(
global_config
[
'io'
][
'afs'
])
else
:
self
.
_pass_donefile_handler
=
fs
.
FileHandler
(
global_config
[
'io'
][
'local_fs'
])
last_done
=
self
.
_pass_donefile_handler
.
cat
(
self
.
_train_pass_donefile
).
strip
().
split
(
'
\n
'
)[
-
1
]
done_fileds
=
last_done
.
split
(
'
\t
'
)
if
len
(
done_fileds
)
>
4
:
self
.
_base_key
=
done_fileds
[
1
]
self
.
_checkpoint_model_path
=
done_fileds
[
2
]
self
.
_checkpoint_pass_id
=
int
(
done_fileds
[
3
])
self
.
_inference_pass_id
=
int
(
done_fileds
[
4
])
self
.
init_pass_by_id
(
done_fileds
[
0
],
self
.
_checkpoint_pass_id
)
def
max_pass_num_day
(
self
):
"""R
"""
return
24
*
60
/
self
.
_interval_per_pass
def
save_train_progress
(
self
,
day
,
pass_id
,
base_key
,
model_path
,
is_checkpoint
):
"""R
"""
if
is_checkpoint
:
self
.
_checkpoint_pass_id
=
pass_id
self
.
_checkpoint_model_path
=
model_path
done_content
=
"%s
\t
%s
\t
%s
\t
%s
\t
%d
\n
"
%
(
day
,
base_key
,
self
.
_checkpoint_model_path
,
self
.
_checkpoint_pass_id
,
pass_id
)
self
.
_pass_donefile_handler
.
write
(
done_content
,
self
.
_train_pass_donefile
,
'a'
)
pass
def
init_pass_by_id
(
self
,
date_str
,
pass_id
):
"""
init pass context with pass_id
Args:
date_str: example "20200110"
pass_id(int): pass_id of date
"""
date_time
=
make_datetime
(
date_str
)
if
pass_id
<
1
:
pass_id
=
0
if
(
date_time
-
self
.
_begin_day
).
total_seconds
()
>
0
:
self
.
_begin_day
=
date_time
self
.
_pass_id
=
pass_id
mins
=
self
.
_interval_per_pass
*
(
pass_id
-
1
)
self
.
_current_train_time
=
date_time
+
datetime
.
timedelta
(
minutes
=
mins
)
def
init_pass_by_time
(
self
,
datetime_str
):
"""
init pass context with datetime
Args:
date_str: example "20200110000" -> "%Y%m%d%H%M"
"""
self
.
_current_train_time
=
make_datetime
(
datetime_str
)
minus
=
self
.
_current_train_time
.
hour
*
60
+
self
.
_current_train_time
.
minute
self
.
_pass_id
=
minus
/
self
.
_interval_per_pass
+
1
def
current_pass
(
self
):
"""R
"""
return
self
.
_pass_id
def
next
(
self
):
"""R
"""
has_next
=
True
old_pass_id
=
self
.
_pass_id
if
self
.
_pass_id
<
1
:
self
.
init_pass_by_time
(
self
.
_begin_day
.
strftime
(
"%Y%m%d%H%M"
))
else
:
next_time
=
self
.
_current_train_time
+
datetime
.
timedelta
(
minutes
=
self
.
_interval_per_pass
)
if
(
next_time
-
self
.
_end_day
).
total_seconds
()
>
0
:
has_next
=
False
else
:
self
.
init_pass_by_time
(
next_time
.
strftime
(
"%Y%m%d%H%M"
))
if
has_next
and
(
self
.
_inference_pass_id
<
self
.
_pass_id
or
self
.
_pass_id
<
old_pass_id
):
self
.
_inference_pass_id
=
self
.
_pass_id
-
1
return
has_next
def
is_checkpoint_pass
(
self
,
pass_id
):
"""R
"""
if
pass_id
<
1
:
return
True
if
pass_id
==
self
.
max_pass_num_day
():
return
False
if
pass_id
%
self
.
_checkpoint_interval
==
0
:
return
True
return
False
def
need_dump_inference
(
self
,
pass_id
):
"""R
"""
return
self
.
_inference_pass_id
<
pass_id
and
pass_id
%
self
.
_dump_inference_interval
==
0
def
date
(
self
,
delta_day
=
0
):
"""
get train date
Args:
delta_day(int): n day afer current_train_date
Return:
date(current_train_time + delta_day)
"""
return
(
self
.
_current_train_time
+
datetime
.
timedelta
(
days
=
delta_day
)
).
strftime
(
"%Y%m%d"
)
def
timestamp
(
self
,
delta_day
=
0
):
"""R
"""
return
(
self
.
_current_train_time
+
datetime
.
timedelta
(
days
=
delta_day
)
).
timestamp
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录