PaddlePaddle / PALM, commit fa625b30

Authored on Dec 27, 2019 by xixiaoyao

fix bugs

Parent: bf089579

Showing 1 changed file with 736 additions and 0 deletions.

paddlepalm/mtl_controller.py (new file, mode 0 → 100755, +736, -0)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import importlib
import multiprocessing
from paddle import fluid
from paddle.fluid import layers
import yaml
import json
import logging
import time
import numpy as np

from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
from paddlepalm.utils.config_helper import PDConfig
from paddlepalm.utils.print_helper import print_dict
from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
from paddlepalm.distribute import data_feeder

from default_settings import *
from task_instance import TaskInstance, check_instances

DEBUG = False
VERBOSE = 0


def _get_basename(f):
    return os.path.splitext(f)[0]


def _get_suffix(f):
    return os.path.splitext(f)[-1]


def _parse_yaml(f, asdict=True, support_cmd_line=False):
    assert os.path.exists(f), "file {} not found.".format(f)
    if support_cmd_line:
        args = PDConfig(yaml_file=f, fuse_args=True)
        args.build()
        return args.asdict() if asdict else args
    else:
        if asdict:
            with open(f, "r") as fin:
                yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
            return yaml_config
        else:
            raise NotImplementedError()


def _parse_json(f, asdict=True, support_cmd_line=False):
    assert os.path.exists(f), "file {} not found.".format(f)
    if support_cmd_line:
        args = PDConfig(json_file=f, fuse_args=support_cmd_line)
        args.build()
        return args.asdict() if asdict else args
    else:
        if asdict:
            with open(f, "r") as fin:
                config = json.load(fin)
            return config
        else:
            raise NotImplementedError()


def _parse_list(string, astype=str):
    assert isinstance(string, str), "{} is not a string.".format(string)
    if ',' not in string:
        return [astype(string)]
    string = string.replace(',', ' ')
    return [astype(i) for i in string.split()]

def _try_float(s):
    try:
        return float(s)
    except (TypeError, ValueError):
        return s

def _check_conf(conf, checklist=None):
    assert isinstance(conf, dict), "{} is not a dict.".format(conf)
    ret = {}
    for k, v in conf.items():
        if isinstance(v, str):
            v = _try_float(v)
        ret[k] = v
    if checklist is not None:
        for k, t in checklist:
            assert k in ret, "required argument {} does NOT exist in config file.".format(k)
            assert isinstance(ret[k], t), "value type of argument {} should be {}".format(k, t)
    return ret

# TODO: add a None mechanism, allowing hidden size, batch size and seqlen to be set to None
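# _check_io verifies that every input attr required by a consumer (e.g. a task
# paradigm) is produced by its upstream producer (a reader or the backbone)
# with a matching shape and dtype; mismatches warn, or raise when strict=True.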
def _check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"):
    for name, attr in in_attr.items():
        assert name in out_attr, in_name + ': ' + name + ' not found in ' + out_name
        if attr != out_attr[name]:
            if strict:
                raise ValueError(name + ': shape or dtype not consistent!')
            else:
                logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(
                    name, in_name, attr, out_name, out_attr[name]))


def _merge_conf(conf1, conf2, conf1_first=True, strict=False):
    assert isinstance(conf1, dict), "{} is not a dict.".format(conf1)
    assert isinstance(conf2, dict), "{} is not a dict.".format(conf2)
    base_conf = conf2 if conf1_first else conf1
    base_conf = base_conf.copy()
    new_conf = conf1 if conf1_first else conf2

    for k, v in new_conf.items():
        if k in base_conf:
            if base_conf[k] != v:
                logging.warning("value of argument {} has been updated to {}.".format(k, v))
                base_conf[k] = v
        else:
            if strict:
                continue
            base_conf[k] = v
    return base_conf
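
# _encode_inputs/_decode_inputs namespace reader variables per task instance by
# prefixing keys with the task scope name (e.g. 'mytask/label_ids'), so the
# inputs of different task instances can coexist in one joint feed dict.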
def _encode_inputs(inputs, scope_name, sep='/', cand_set=None):
    outputs = {}
    for k, v in inputs.items():
        if cand_set is not None:
            if k in cand_set:
                outputs[k] = v
            if scope_name + sep + k in cand_set:
                outputs[scope_name + sep + k] = v
        else:
            outputs[scope_name + sep + k] = v
    return outputs


def _decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True):
    outputs = {}
    for name, value in inputs.items():
        # vars of the backbone are also available to tasks
        if keep_unk_keys and sep not in name:
            outputs[name] = value
        # vars of this instance
        if name.startswith(scope_name + '/'):
            outputs[name[len(scope_name + '/'):]] = value
    return outputs

def _init_env(use_gpu):
    if use_gpu:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    return fluid.Executor(place), dev_count


def _fit_attr(conf, fit_attr, strict=False):
    for i, attr in fit_attr.items():
        if i not in conf:
            if strict:
                raise Exception('Argument {} is required to create a controller.'.format(i))
            else:
                continue
        conf[i] = attr(conf[i])
    return conf
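
# create_feed_batch_process_fn builds a closure that renames the keys of a
# reader batch to the names of the corresponding feed variables, so the batch
# can be passed directly to Executor.run(feed=...).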
def create_feed_batch_process_fn(net_inputs):

    def feed_batch_process_fn(data):
        temp = {}
        for q, var in net_inputs.items():
            # NOTE: `unicode` is a Python 2 built-in; this module targets Python 2.
            if isinstance(var, str) or isinstance(var, unicode):
                temp[var] = data[q]
            else:
                temp[var.name] = data[q]
        return temp

    return feed_batch_process_fn
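
# Controller orchestrates multi-task learning: it builds one shared backbone,
# instantiates a reader and a task layer per task instance, mixes the task
# iterators according to their mix_ratio, and runs joint training/prediction.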
class Controller(object):

    def __init__(self, config, task_dir='.', for_train=True):
        """
        Args:
            config: (str|dict) when given as a string, it is the path of a yaml-format config file;
        """
        self._for_train = for_train
        assert isinstance(config, str) or isinstance(config, dict), \
            "a config dict or config file path is required to create a Controller."

        if isinstance(config, str):
            mtl_conf = _parse_yaml(config, support_cmd_line=True)
        else:
            mtl_conf = config

        mtl_conf = _check_conf(mtl_conf)
        mtl_conf = _fit_attr(mtl_conf, REQUIRED_ARGS, strict=True)
        mtl_conf = _fit_attr(mtl_conf, OPTIONAL_ARGS, strict=False)
        self.mtl_conf = mtl_conf  # referenced later by pred() to locate inference models

        exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True))
        self.exe = exe
        self.dev_count = dev_count

        print_dict(mtl_conf, title='global configuration')
        # parse task instances and target tags
        instnames = _parse_list(mtl_conf['task_instance'])
        assert len(instnames) == len(set(instnames)), "repeated task_instance is NOT supported."
        num_instances = len(instnames)
        self.num_instances = num_instances

        instname_to_conf = {}
        instname_to_id = {}
        for id, instname in enumerate(instnames):
            instpath = os.path.join(task_dir, instname + '.yaml')
            conf = _parse_yaml(instpath, support_cmd_line=False)
            # conf = _check_conf(conf, TASK_INSTANCE_REQUIRED_ARGS)
            conf = _check_conf(conf)
            temp_conf = _merge_conf(mtl_conf, conf, strict=True)
            print_dict(temp_conf, title='{} configuration'.format(instname))
            conf = _merge_conf(mtl_conf, conf)

            instname_to_conf[instname] = conf
            instname_to_id[instname] = id
        # prepare backbone
        if 'backbone_config_path' in mtl_conf:
            bb_conf = _parse_json(mtl_conf['backbone_config_path'])
            bb_conf = _merge_conf(mtl_conf, bb_conf)
        else:
            bb_conf = mtl_conf
        print_dict(bb_conf, title='backbone configuration')

        bb_name = mtl_conf['backbone']
        bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
        Backbone = getattr(bb_mod, 'Model')
        # create task instances
        instances = []
        for name in instnames:
            instances.append(TaskInstance(name, instname_to_id[name], instname_to_conf[name]))

        check_instances(instances)

        # parse target_tag
        if 'target_tag' in mtl_conf:
            target_tag = str(mtl_conf['target_tag'])
            tags = _parse_list(target_tag, astype=int)
            assert len(tags) == len(instnames), "number of target_tag is NOT consistent with that in task_instance."
            for tag, inst in zip(tags, instances):
                inst.is_target = tag
        else:
            tags = [i.is_target for i in instances]

        num_targets = sum(tags)
        num_auxes = num_instances - num_targets

        # parse mix ratios
        if 'mix_ratio' in mtl_conf:
            mix_ratio = str(mtl_conf['mix_ratio'])
            mrs = _parse_list(mix_ratio, astype=float)
            assert len(mrs) == num_instances, "number of mix_ratios is NOT consistent with num_instances."
        else:
            mrs = [1.0] * num_instances

        for mr, inst in zip(mrs, instances):
            inst.mix_ratio = mr
        # parse task layer reuse tags
        instname_to_reusehost = {i: i for i in instnames}
        if 'task_reuse_tag' in mtl_conf:
            tags = _parse_list(mtl_conf['task_reuse_tag'], astype=int)
            assert len(tags) == num_targets, 'number of reuse_tags is NOT consistent with number of instances.'
        else:
            tags = []
            mapper = {}
            name_to_instance = {i.name: i for i in instances}
            for inst in instances:
                history = set()
                history.add(inst.name)
                cur_inst = inst
                while True:
                    if cur_inst.task_reuse_scope in history:
                        mapper[inst.name] = len(tags)
                        break
                    elif cur_inst.task_reuse_scope in mapper:
                        mapper[inst.name] = mapper[cur_inst.task_reuse_scope]
                        break
                    else:
                        cur_inst = name_to_instance[cur_inst.task_reuse_scope]
                        history.add(cur_inst.name)

                tags.append(mapper[inst.name])
        for i in range(1, num_instances):
            for j in range(i):
                if tags[i] == tags[j]:
                    assert instances[i].Paradigm == \
                        instances[j].Paradigm, \
                        "paradigm of reuse tasks should be consistent"
                    instances[i].task_reuse_scope = instances[j].name
                    break
        self.instances = instances
        self.mrs = mrs
        self.Backbone = Backbone
        self.bb_conf = bb_conf
        self.bb_name = bb_name

        self.has_init_train = False
        self.has_init_pred = False

        if self._for_train:
            print("initializing for training...")
            self._init_train()
            self.has_init_train = True

    def _init_train(self):

        instances = self.instances
        Backbone = self.Backbone
        bb_conf = self.bb_conf
        bb_name = self.bb_name
        dev_count = self.dev_count
        num_instances = len(instances)
        mrs = self.mrs

        # set first_target/main task instance
        main_inst = None
        for inst in instances:
            if inst.is_target:
                main_inst = inst
                inst.is_first_target = True
                break
        main_conf = main_inst.config
        if not os.path.exists(main_conf['save_path']):
            os.makedirs(main_conf['save_path'])
            os.makedirs(os.path.join(main_conf['save_path'], 'ckpt'))

        # prepare backbone
        train_backbone = Backbone(bb_conf, phase='train')
        pred_backbone = Backbone(bb_conf, phase='pred')

        # create reader, task
        # then check i/o across reader, backbone and task_layer
        task_attrs = []
        pred_task_attrs = []
        for inst in instances:
            train_reader = inst.Reader(inst.config, phase='train')
            inst.reader['train'] = train_reader
            train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf)
            inst.task_layer['train'] = train_parad
            task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], inst.name)
            task_attrs.append(task_attr_from_reader)

            _check_io(train_backbone.inputs_attr, train_reader.outputs_attr,
                      in_name=bb_name + '_backbone', out_name='reader.train')
            _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr,
                      in_name='task_paradigm.train.reader', out_name='reader.train')
            _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr,
                      in_name='task_paradigm.train.backbone', out_name=bb_name + '_backbone')

            if inst.is_target:
                if 'pred_file' not in inst.config:
                    inst.config['pred_file'] = ''
                pred_reader = inst.Reader(inst.config, phase='pred')
                pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf)
                inst.task_layer['pred'] = pred_parad
                task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name)
                pred_task_attrs.append(task_attr_from_reader)
                _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr,
                          in_name=bb_name + '_backbone', out_name='reader.pred')
                _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr,
                          in_name='task_paradigm.pred.reader', out_name='reader.pred')
                _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr,
                          in_name='task_paradigm.pred.backbone', out_name=bb_name + '_backbone')

        # merge reader input attrs from backbone and task_instances
        joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(
            train_backbone.inputs_attr, task_attrs)
        pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(
            pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False,
            insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
        # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]

        if DEBUG:
            print('----- for debug -----')
            print('joint input names:')
            print(joint_input_names)
            print('joint input shape and dtypes:')
            print(joint_shape_and_dtypes)

        # load data
        for inst in instances:
            print(inst.name + ": preparing data...", end='')
            inst.reader['train'].load_data()
            print('ok!')

        # merge dataset iterators and create net input vars
        iterators = []
        prefixes = []
        mrs = []
        for inst in instances:
            iterators.append(inst.reader['train'].iterator())
            prefixes.append(inst.name)
            mrs.append(inst.mix_ratio)

        joint_iterator_fn = create_joint_iterator_fn(
            iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position,
            dev_count=dev_count, verbose=VERBOSE, return_type='dict')
        self._joint_iterator_fn = joint_iterator_fn

        input_attrs = [[i, j, k] for i, (j, k) in zip(joint_input_names, joint_shape_and_dtypes)]
        pred_input_attrs = [[i, j, k] for i, (j, k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)]
        # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3)
        net_inputs = create_net_inputs(input_attrs, async=False)
        self._net_inputs = net_inputs

        # build backbone and task layers
        train_prog = fluid.default_main_program()
        train_init_prog = fluid.default_startup_program()
        bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
        assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys())

        pred_prog = fluid.Program()
        pred_init_prog = fluid.Program()
        with fluid.program_guard(main_program=pred_prog, startup_program=pred_init_prog):
            pred_net_inputs = create_net_inputs(pred_input_attrs)
            pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')

        fluid.framework.switch_main_program(train_prog)
        fluid.framework.switch_startup_program(train_init_prog)

        task_output_vars = {}
        for inst in instances:
            task_inputs = {'backbone': bb_output_vars}
            task_inputs_from_reader = _decode_inputs(net_inputs, inst.name)
            task_inputs['reader'] = task_inputs_from_reader

            scope = inst.task_reuse_scope + '/'
            with fluid.unique_name.guard(scope):
                output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope)
                output_vars = {inst.name + '/' + key: val for key, val in output_vars.items()}
                old = len(task_output_vars)  # for debug
                task_output_vars.update(output_vars)
                assert len(task_output_vars) - old == len(output_vars)  # for debug

            # prepare predict vars for saving inference model
            if inst.is_target:
                with fluid.program_guard(pred_prog, pred_init_prog):
                    cur_inputs = _decode_inputs(pred_net_inputs, inst.name)
                    inst.pred_input = cur_inputs
                    pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
                    scope = inst.task_reuse_scope + '/'
                    with fluid.unique_name.guard(scope):
                        inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope)

        bb_fetches = {k: v.name for k, v in bb_output_vars.items()}
        task_fetches = {k: v.name for k, v in task_output_vars.items()}
        fetches = task_fetches
        fetches['__task_id'] = net_inputs['__task_id'].name
        # compute loss
        task_id_var = net_inputs['__task_id']
        task_id_vec = fluid.one_hot(task_id_var, num_instances)
        losses = fluid.layers.concat([task_output_vars[inst.name + '/loss'] for inst in instances], axis=0)
        loss = layers.reduce_sum(task_id_vec * losses)
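        # NOTE: every task layer contributes its loss to the concatenated
        # `losses` vector, so multiplying by the one-hot task-id vector and
        # summing selects exactly the loss of the task the current batch
        # belongs to.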

        main_reader = main_inst.reader['train']

        num_examples = main_reader.num_examples
        for inst in instances:
            max_train_steps = int(main_conf['num_epochs'] * inst.mix_ratio *
                                  (num_examples // main_conf['batch_size'] // dev_count))
            if inst.is_target:
                print('{}: expected train steps {}.'.format(inst.name, max_train_steps))
            inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count
            inst.expected_train_steps = max_train_steps

        global_max_train_steps = int(main_conf['num_epochs'] * sum(mrs) *
                                     (num_examples // main_conf['batch_size'] // dev_count))
        print('Estimated overall train steps {}.'.format(global_max_train_steps))

        if 'warmup_proportion' in main_conf and main_conf['warmup_proportion'] > 0:
            warmup_steps = int(global_max_train_steps * main_conf['warmup_proportion'])
            print('Warmup steps: ' + str(warmup_steps))
        else:
            warmup_steps = 0

        # build optimizer
        if 'optimizer' in main_conf:
            optim_mod = importlib.import_module(OPTIMIZER_DIR + '.' + main_conf['optimizer'])
            optimize = getattr(optim_mod, OPTIMIZE_METHOD)
            optimize(loss, main_conf, max_train_steps, warmup_steps, fluid.default_main_program())

            loss.persistable = True
            if main_conf.get('use_ema', False):
                assert 'ema_decay' in main_conf, "ema_decay should be set when use_ema is enabled."
                ema = fluid.optimizer.ExponentialMovingAverage(main_conf['ema_decay'])
                ema.update()
        # prepare for train
        self.train_backbone = train_backbone
        self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
        self.saver_program = fluid.default_main_program()

        self.main_inst = main_inst
        self.main_conf = main_conf  # referenced by load_pretrain() and __main__
        self.fetches = fetches
        self.has_init_train = True
        self.has_init_pred = True

        self.exe.run(fluid.default_startup_program())
        print("\nRandomly initialize parameters...\n")

    def _init_pred(self, instance, infer_model_path):
        inst = instance
        if 'pred_output_path' not in inst.config:
            inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name)

        if not os.path.exists(inst.config['pred_output_path']):
            os.makedirs(inst.config['pred_output_path'])

        pred_backbone = self.Backbone(self.bb_conf, phase='pred')
        pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf)
        inst.task_layer['pred'] = pred_parad

        pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs(
            pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'],
            insert_taskid=False, insert_batchsize=False, insert_seqlen=False,
            insert_batchsize_x_seqlen=False)

        pred_prog = inst.load(infer_model_path)
        pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
        if inst.reader['pred'] is None:
            pred_reader = inst.Reader(inst.config, phase='pred')
            inst.reader['pred'] = pred_reader
        return pred_prog

    def load_pretrain(self, pretrain_path=None):
        # load pretrain model (or ckpt)
        if pretrain_path is None:
            assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set."
            pretrain_path = self.main_conf['pretrain_path']

        init_pretraining_params(self.exe, pretrain_path, main_program=fluid.default_startup_program())

    def train(self):
        if not self.has_init_train:
            self._init_train()
            self.has_init_train = True

        instances = self.instances
        num_instances = self.num_instances
        main_inst = self.main_inst
        main_conf = main_inst.config

        backbone = self.train_backbone
        train_program = self.train_program
        saver_program = self.saver_program
        fetches = self.fetches

        finish = []
        for inst in instances:
            if inst.is_target:
                if inst.expected_train_steps > 0:
                    finish.append(False)
                else:
                    finish.append(True)
                    print(inst.name + ': train finished!')
                    inst.save()

        def train_finish():
            for inst in instances:
                if inst.is_target:
                    if not inst.train_finish:
                        return False
            return True

        # do training
        fetch_names, fetch_list = zip(*fetches.items())

        main_step = 0  # only count for main task
        global_step = 0  # count for all tasks
        epoch = 0
        time_begin = time.time()
        backbone_buffer = []

        feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
        distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)
        # palm.distribute.reader(self._joint_iterator_fn, self._net_inputs, prefetch_steps=2)
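        # NOTE: the feeder yields (feed, mask) pairs; `mask` flags which device
        # slots carry real batches. Outputs fetched for padded slots are popped
        # right after Executor.run below.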
        while not train_finish():
            feed, mask = next(distribute_feeder)
            rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
            while mask.pop() == False:
                rt_outputs.pop()

            rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}
            rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist()
            rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id
            cur_task = instances[rt_task_id]

            backbone_rt_outputs = {k: v for k, v in rt_outputs.items() if '/' not in k}
            backbone_buffer.append(backbone.postprocess(backbone_rt_outputs))

            task_rt_outputs = {k[len(cur_task.name + '/'):]: v for k, v in rt_outputs.items()
                               if k.startswith(cur_task.name + '/')}
            instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs)

            global_step += 1
            cur_task.cur_train_step += 1

            cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
            if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 \
                    and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
                cur_task.save(suffix='.step' + str(cur_task_global_step))

            if global_step % main_conf.get('print_every_n_steps', 5) == 0:
                loss = rt_outputs[cur_task.name + '/loss']
                loss = np.mean(np.squeeze(loss)).tolist()

                time_end = time.time()
                time_cost = time_end - time_begin

                print("Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
                    global_step, cur_task.name, cur_task.cur_train_step, cur_task.steps_pur_epoch,
                    cur_task.cur_train_epoch, loss,
                    main_conf.get('print_every_n_steps', 5) / time_cost))
                time_begin = time.time()

            if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
                print(cur_task.name + ': train finished!')
                cur_task.save()

            if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0:
                save_path = os.path.join(main_conf['save_path'], 'ckpt', "step_" + str(global_step))
                fluid.io.save_persistables(self.exe, save_path, saver_program)
                print('checkpoint has been saved at ' + save_path)

        save_path = os.path.join(main_conf['save_path'], 'ckpt', "step_" + str(global_step))
        fluid.io.save_persistables(self.exe, save_path, saver_program)
        print('checkpoint has been saved at ' + save_path)

        print("ALL tasks train finished, exiting...")

    def pred(self, task_instance, inference_model_dir=None):
        if self._for_train:
            raise Exception('This controller is a trainer. Please build a new controller with for_train=False for predicting.')

        assert isinstance(task_instance, str)
        if isinstance(inference_model_dir, str):
            assert os.path.exists(inference_model_dir), inference_model_dir + " not found."
        # if not self.has_init_pred and inference_model_dir is None:
        #     raise ValueError('infer_model_path is required for prediction.')
        if inference_model_dir is None:
            assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
            inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model')

        instance = None
        for inst in self.instances:
            if inst.name == task_instance:
                instance = inst
                break

        if instance is None:
            raise ValueError(task_instance + ' is not a valid task_instance.')

        pred_prog = self._init_pred(instance, inference_model_dir)

        inst = instance
        print(inst.name + ": loading data...")
        inst.reader['pred'].load_data()
        fetch_names, fetch_vars = inst.pred_fetch_list

        print('predicting...')
        feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
        distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1)

        buf = []
        for feed, mask in distribute_feeder:
            print('before run')
            rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
            print('after run')
            splited_rt_outputs = []
            for item in rt_outputs:
                splited_rt_outputs.append(np.split(item, len(mask)))

            # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)]
            print(mask)
            while mask.pop() == False:
                print(mask)
                for item in splited_rt_outputs:
                    item.pop()
            rt_outputs = []
            print('concat')
            for item in splited_rt_outputs:
                rt_outputs.append(np.concatenate(item))

            rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}
            inst.postprocess(rt_outputs, phase='pred')

        print('leave feeder')
        if inst.task_layer['pred'].epoch_inputs_attrs:
            reader_outputs = inst.reader['pred'].get_epoch_outputs()
        else:
            reader_outputs = None

        print('epoch postprocess')
        inst.epoch_postprocess({'reader': reader_outputs}, phase='pred')


if __name__ == '__main__':
    assert len(sys.argv) == 2, "Usage: python mtl_controller.py <mtl_conf_path>"
    conf_path = sys.argv[1]
    del sys.argv[1]
    controller = Controller(conf_path)
    if controller.main_conf['do_train']:
        controller.train()


__all__ = ["Controller"]
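
For reference, a minimal sketch of driving this controller end to end. The file name demo.yaml and all values below are hypothetical; the key names are the ones this module reads out of its config (task_instance, target_tag, mix_ratio, backbone, backbone_config_path, optimizer, batch_size, num_epochs, warmup_proportion, save_path, pretrain_path, do_train), while the full schema is defined by default_settings.py and the task/backbone modules:

    # demo.yaml (hypothetical values; key names taken from the lookups above)
    task_instance: "mrqa"              # resolved to <task_dir>/mrqa.yaml
    target_tag: 1                      # 1 = target task, 0 = auxiliary task
    mix_ratio: 1.0                     # relative sampling weight of this task
    backbone: "bert"                   # imported as BACKBONE_DIR + '.bert', class Model
    backbone_config_path: "bert_config.json"
    optimizer: "adam"                  # imported as OPTIMIZER_DIR + '.adam'
    batch_size: 4
    num_epochs: 2
    warmup_proportion: 0.1
    save_path: "output_model"
    pretrain_path: "pretrain/bert"     # consumed by Controller.load_pretrain()
    do_train: True

    # shell:
    # python mtl_controller.py demo.yaml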