Repository: 机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit e59463ef (unverified)
test=develop, add distributed tools (#22623)

Authored by 123malin on Feb 16, 2020; committed via GitHub on Feb 16, 2020.
Parent commit: 1aab3e61

Showing 4 changed files with 698 additions and 1 deletion:
python/paddle/fluid/incubate/fleet/utils/fleet_barrier_util.py   +1    -0
python/paddle/fluid/incubate/fleet/utils/fleet_util.py           +81   -1
python/paddle/fluid/incubate/fleet/utils/utils.py                +428  -0
python/paddle/fluid/tests/unittests/test_fleet_utils.py          +188  -0
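The new helpers are exposed through FleetUtil. Below is a minimal usage sketch, not taken from the commit itself: it only calls methods added in this diff (program_type_trans, parse_program_proto), and the file paths are placeholders.

# Hypothetical usage sketch of the tools added in this commit; paths are placeholders.
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil

fleet_util = FleetUtil()  # defaults to "pslib" mode; "transpiler" is also accepted

# Convert a text-format program.proto into its binary form (returns the output filename).
binary_name = fleet_util.program_type_trans("./pruned_model",
                                            "pruned_main_program.pbtxt", True)

# Dump vars/ops of a program into vars_all.log, vars_persistable.log and ops.log under /tmp/.
fleet_util.parse_program_proto("./pruned_model/pruned_main_program.pbtxt", True, "/tmp/")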
python/paddle/fluid/incubate/fleet/utils/fleet_barrier_util.py
@@ -15,6 +15,7 @@
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.contrib.utils import HDFSClient
import os
import time


def check_all_trainers_ready(ready_path, epoch):
    ...
python/paddle/fluid/incubate/fleet/utils/fleet_util.py
@@ -23,15 +23,19 @@ import sys
import time
import paddle.fluid as fluid
from paddle.fluid.log_helper import get_logger
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet as fleet_pslib
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
from . import hdfs
from .hdfs import *
from . import utils

__all__ = ["FleetUtil"]

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')

fleet = fleet_pslib


class FleetUtil(object):
    """
@@ -46,6 +50,16 @@ class FleetUtil(object):
    """

    def __init__(self, mode="pslib"):
        global fleet
        if mode == "pslib":
            fleet = fleet_pslib
        elif mode == "transpiler":
            fleet = fleet_transpiler
        else:
            raise ValueError(
                "Please choose one mode from [\"pslib\", \"transpiler\"]")

    def rank0_print(self, s):
        """
        Worker of rank 0 print some log.
        """
    ...
@@ -1535,3 +1549,69 @@ class FleetUtil(object):
                       (print_prefix, auc, bucket_error, mae, rmse,
                        actual_ctr, predicted_ctr, copc, mean_predict_qvalue,
                        total_ins_num))

    def program_type_trans(self, prog_dir, prog_fn, is_text):
        return utils.program_type_trans(prog_dir, prog_fn, is_text)

    def draw_from_program_file(self, model_filename, is_text, output_dir,
                               output_filename):
        """draw program from file"""
        program = utils.load_program(model_filename, is_text)
        utils.graphviz(program.global_block(), output_dir, output_filename)

    def draw_from_program(self, program, output_dir, output_name):
        """draw Program"""
        utils.graphviz(program.global_block(), output_dir, output_name)

    def check_two_programs(self, config):
        train_prog = utils.load_program(config.train_prog_path,
                                        config.is_text_train_program)
        pruned_prog = utils.load_program(config.pruned_prog_path,
                                         config.is_text_pruned_program)
        if config.draw:
            pruned_dir = os.path.dirname(config.pruned_prog_path)
            self.draw_from_program(pruned_prog, pruned_dir,
                                   config.draw_out_name)
        res = utils.check_pruned_program_vars(train_prog, pruned_prog)
        if res:
            _logger.info("check_programs succeed.")
        else:
            _logger.info(
                "check_programs failed. pruned program and train program not match!"
            )
        return res

    def check_vars_and_dump(self, config):
        _logger.info("start check_vars_and_dump.")
        results = utils.check_saved_vars_try_dump(
            config.dump_model_dir, config.dump_program_filename,
            config.is_text_dump_program, config.feed_config,
            config.fetch_config, config.batch_size,
            config.save_params_filename)
        _logger.info("check_vars_and_dump succeed.")
        return results

    def parse_program_proto(self, prog_path, is_text, output_dir):
        """
        Parse program.proto into a more readable format.
        This function will generate three files:
        output_dir/vars_all.log,
        output_dir/vars_persistable.log,
        output_dir/ops.log.

        Args:
            prog_path(str): proto file path to be parsed.
            is_text(bool): whether the proto file is in human-readable text format (otherwise binary).
            output_dir(str): output dir.

        Examples:
            .. code-block:: python

                from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
                fleet_util = FleetUtil()
                program_path = "./program.pbtxt"
                is_text = True
                output_dir = "/tmp/"
                fleet_util.parse_program_proto(program_path, is_text, output_dir)
        """
        program = utils.load_program(prog_path, is_text)
        utils.parse_program(program, output_dir)
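check_two_programs and check_vars_and_dump both read their options from a plain config object whose attribute names appear in the methods above. A minimal sketch, not from the commit, with placeholder local paths:

# Hypothetical config for check_two_programs; attribute names follow the method above,
# file paths are placeholders.
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil


class Config(object):
    pass


conf = Config()
conf.train_prog_path = "./train_program/join_main_program.pbtxt"
conf.is_text_train_program = True
conf.pruned_prog_path = "./pruned_model/pruned_main_program.pbtxt"
conf.is_text_pruned_program = True
conf.draw = False  # set True (and set conf.draw_out_name) to also emit a graphviz drawing

fleet_util = FleetUtil()
matched = fleet_util.check_two_programs(conf)  # True if pruned vars match the train program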
python/paddle/fluid/incubate/fleet/utils/utils.py (new file, 0 → 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, absolute_import
import os
import sys
import logging
import subprocess
import numpy as np
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from google.protobuf import text_format
from paddle.fluid import debugger
from paddle.fluid.framework import Program
from paddle.fluid.proto import framework_pb2

__all__ = [
    "load_program", "save_program", "program_type_trans",
    "check_saved_vars_try_dump", "parse_program", "check_pruned_program_vars",
    "graphviz"
]

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

persistable_vars_out_fn = "vars_persistable.log"
all_vars_out_fn = "vars_all.log"
ops_out_fn = "ops.log"

feed_fetch_type_list = [
    core.VarDesc.VarType.FEED_MINIBATCH, core.VarDesc.VarType.FETCH_LIST
]
not_expected_op_types = ["lookup_table"]


def load_program(model_filename, is_text=False):
    if is_text:
        return load_program_text(model_filename)
    return load_program_binary(model_filename)


def load_program_binary(model_filename):
    """load program from binary string file"""
    with open(model_filename, "rb") as f:
        program_desc_str = f.read()
    return Program.parse_from_string(program_desc_str)


def load_program_text(model_filename):
    """load program from human-readable text file"""
    with open(model_filename, "r") as f:
        program_desc_text = f.read()

    prog_desc = framework_pb2.ProgramDesc()
    text_format.Merge(program_desc_text, prog_desc)
    return Program.parse_from_string(prog_desc.SerializeToString())


def save_program(program, model_filename='__model__', is_text=False):
    if is_text:
        with open(model_filename, "w") as f:
            f.write(str(program))
    else:
        with open(model_filename, "wb") as f:
            f.write(program.desc.serialize_to_string())


def check_pruned_program_vars(train_prog, pruned_prog):
    is_match = True

    pruned_vars = [(v.name, v) for v in pruned_prog.list_vars()
                   if fluid.io.is_persistable(v)]
    pruned_vars = OrderedDict(pruned_vars)
    pruned_vars_name = [name for name in pruned_vars]
    logger.info("persistable vars in pruned program: {}".format(
        pruned_vars_name))

    for var_name in pruned_vars:
        var = pruned_vars[var_name]
        # feed and fetch op is added in pruned program when pruning, not need to be found in train program
        if var.type in feed_fetch_type_list:
            break
        try:
            train_prog_var = train_prog.global_block().var(var_name)
        except ValueError as e:
            logger.error(
                "not find variable '%s' in train program. please check pruning."
                % var_name)
            logger.error(e)
            continue
        if var.shape != train_prog_var.shape or var.dtype != train_prog_var.dtype:
            logger.error(
                "variable: {} not match. in pruned program shape: {} dtype:{}, in train program shape: {} dtype: {}".
                format(var_name, var.shape, var.dtype, train_prog_var.shape,
                       train_prog_var.dtype))
            is_match = False
    return is_match


def graphviz(block, output_dir="", filename='debug'):
    dot_path = os.path.join(output_dir, filename + '.dot')
    pdf_path = os.path.join(output_dir, filename + '.pdf')
    debugger.draw_block_graphviz(block, path=dot_path)
    cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path]
    p = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    p.wait()


def program_type_trans(prog_dir, prog_fn, is_text):
    prog = load_program(os.path.join(prog_dir, prog_fn), is_text)
    prog_out_fn = prog_fn + ".bin" if is_text else prog_fn + ".pbtxt"
    save_program(prog, os.path.join(prog_dir, prog_out_fn), 1 - is_text)
    return prog_out_fn


def append_save_op(block, var, path):
    block.append_op(
        type='save', inputs={'X': [var]}, outputs={},
        attrs={'file_path': path})


def append_load_op(block, var, path):
    block.append_op(
        type='load', inputs={}, outputs={'Out': [var]},
        attrs={'file_path': path})


def save_var(np_array, var_name, shape_list, dtype, save_path):
    program = fluid.Program()
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    with fluid.program_guard(program):
        d0_data = fluid.layers.data(var_name, shape=shape_list, dtype=dtype)
        append_save_op(program.global_block(), d0_data, save_path)
        exe.run(feed={var_name: np_array}, fetch_list=[])


def load_var(var_name, shape_list, dtype, save_path):
    program = fluid.Program()
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    with fluid.program_guard(program):
        d0_data = fluid.layers.data(var_name, shape=shape_list, dtype=dtype)
        append_load_op(program.global_block(), d0_data, save_path)
        outs = exe.run(feed={}, fetch_list=[d0_data])
        return outs


def reader(batch_size, fn, dim):
    data = []
    if isinstance(dim, list) or isinstance(dim, tuple):
        shape = list(dim)
        _temp = 1
        for x in dim:
            _temp = _temp * x
        dim = _temp
    else:
        shape = [dim]

    shape = [batch_size] + shape
    dim = dim * batch_size

    for line in open(fn, 'r'):
        fields = line.strip().split(' ')
        fields = [float(d) for d in fields]
        while len(fields) >= dim:
            tmp = fields[:dim]
            fields = fields[dim:]
            data.append(np.array(tmp).reshape(shape))
    return data


def feed_gen(batch_size, feeded_vars_dims, feeded_vars_filelist):
    batch_feed = []
    for i, fn in enumerate(feeded_vars_filelist):
        batch_feed.append(reader(batch_size, fn, feeded_vars_dims[i]))
    return batch_feed


def try_load_model_vars(dump_dir, dump_prog_fn, is_text_dump_program,
                        batch_size, feed_config, fetch_config, save_filename,
                        saved_params):
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    scope = fluid.core.Scope()
    with fluid.scope_guard(scope):
        if is_text_dump_program:
            dump_prog_fn = program_type_trans(dump_dir, dump_prog_fn,
                                              is_text_dump_program)
        inference_program, feed_target_names, fetch_targets = \
            fluid.io.load_inference_model(
                dump_dir,
                exe,
                model_filename=dump_prog_fn,
                params_filename=save_filename)

        # check program vars and saved vars shape
        orig_para_shape = {
            each_var.name: tuple(each_var.desc.shape())
            for each_var in saved_params
        }
        for each_var in saved_params:
            var_temp = fluid.global_scope().find_var(each_var.name)
            assert var_temp != None, "can't not find var: " + each_var.name
            new_shape = (np.array(var_temp.get_tensor())).shape
            assert each_var.name in orig_para_shape, each_var.name + "MUST in var list"
            orig_shape = orig_para_shape.get(each_var.name)
            if new_shape != orig_shape:
                raise RuntimeError(
                    "Shape not matching: the Program requires a parameter with a shape of ({}), "
                    "while the loaded parameter (namely [ {} ]) has a shape of ({})."
                    .format(orig_shape, each_var.name, new_shape))

        # check feed/fetch vars in program and config
        fetch_targets_names = [v.name for v in fetch_targets]
        if not feed_target_names:
            logger.warning("no feed targets in program.")
        if not fetch_targets_names:
            logger.warning("no fetch targets in program.")
        fetch_list = fetch_targets
        feed_name_list = feed_target_names
        if feed_config.feeded_vars_names is not None and feed_target_names != feed_config.feeded_vars_names:
            logger.warning(
                "feed vars in program and config are diff: feed in program: {}. feed in config {}.".
                format(feed_target_names, feed_config.feeded_vars_names))
            feed_name_list = feed_config.feeded_vars_names
            # remove feed op in inference_program. new feed op will be added in exe.run
            global_block = inference_program.global_block()
            need_to_remove_op_index = []
            for i, op in enumerate(global_block.ops):
                op.desc.set_is_target(False)
                if op.type == "feed":  # only remove feed op here
                    need_to_remove_op_index.append(i)
            for index in need_to_remove_op_index[::-1]:
                global_block._remove_op(index)
        if fetch_config.fetch_vars_names is not None and fetch_targets_names != fetch_config.fetch_vars_names:
            logger.warning(
                "fetch vars in program and config are diff: fetch in program: {}. fetch in config {}.".
                format(fetch_targets_names, fetch_config.fetch_vars_names))
            fetch_list = [
                inference_program.global_block().var(i)
                for i in fetch_config.fetch_vars_names
            ]
            # remove fetch op in inference_program. new fetch op will be added in exe.run
            global_block = inference_program.global_block()
            need_to_remove_op_index = []
            for i, op in enumerate(global_block.ops):
                op.desc.set_is_target(False)
                if op.type == "fetch":  # only remove fetch op here
                    need_to_remove_op_index.append(i)
            for index in need_to_remove_op_index[::-1]:
                global_block._remove_op(index)

        # if fetch_list have lod tensor
        return_numpy = all([v.lod_level == 0 for v in fetch_list])

        # try dump fetch_targets
        feed_tensors = []
        assert len(feed_config.feeded_vars_names) == len(
            feed_config.feeded_vars_dims) == len(feed_config.feeded_vars_types)
        # check program vars and feed tensor shape in config
        for i in range(len(feed_config.feeded_vars_names)):
            var = inference_program.global_block().var(
                feed_config.feeded_vars_names[i])
            if not isinstance(feed_config.feeded_vars_dims[i], (list, tuple)):
                tensor_shape = (feed_config.feeded_vars_dims[i], )
            else:
                tensor_shape = tuple(feed_config.feeded_vars_dims[i])
            feed_config.feeded_vars_dims[i] = tensor_shape
            var_shape = var.shape[1:]
            if tensor_shape != var_shape:
                raise RuntimeError(
                    "feed variable '{}' shape not match. infer program shape: {}. feed tensor shape: {}".
                    format(feed_config.feeded_vars_names[i], var_shape,
                           tensor_shape))

        if not feed_config.feeded_vars_filelist:
            logger.info("generate random feed vars.")
            for i in range(len(feed_config.feeded_vars_names)):
                var = inference_program.global_block().var(
                    feed_config.feeded_vars_names[i])
                # create fake feed tensor. if lod_level > 1, should create_lod_tensor()
                if var.lod_level == 0:
                    feed_tensors.append(
                        np.array(
                            np.random.random(
                                tuple([batch_size] + list(
                                    feed_config.feeded_vars_dims[i]))),
                            dtype=feed_config.feeded_vars_types[i]))
                elif var.lod_level == 1:
                    t = np.array(
                        np.random.random(
                            tuple([batch_size] + list(
                                feed_config.feeded_vars_dims[i]))),
                        dtype=feed_config.feeded_vars_types[i])
                    feed_tensors.append(
                        fluid.create_lod_tensor(t, [[1] * batch_size], place))
                else:
                    raise RuntimeError(
                        "vars with lod_level >= 2 is not supported now in this infer program check tool."
                    )
            results = exe.run(inference_program,
                              feed={
                                  name: feed_tensors[i]
                                  for i, name in enumerate(feed_name_list)
                              },
                              fetch_list=fetch_list,
                              return_numpy=return_numpy)
        else:
            logger.info("load feed vars from files: {}.".format(
                feed_config.feeded_vars_filelist))
            feed_vars = [
                inference_program.global_block().var(
                    feed_config.feeded_vars_names[i])
                for i in range(len(feed_config.feeded_vars_names))
            ]
            feeder = fluid.DataFeeder(feed_list=feed_vars, place=place)
            batch_feed = feed_gen(batch_size, feed_config.feeded_vars_dims,
                                  feed_config.feeded_vars_filelist)
            slots = [batch_feed]
            results = exe.run(inference_program,
                              feed=feeder.feed(slots),
                              fetch_list=fetch_list,
                              return_numpy=return_numpy)
        for i, v in enumerate(fetch_list):
            logger.info("fetch_targets name: %s" % v.name)
            logger.info("fetch_targets: {}".format(results[i]))

        return results


def check_not_expected_ops(prog):
    op_types_set = set()
    for op in prog.global_block().ops:
        if op.type in not_expected_op_types and op.type not in op_types_set:
            logger.warning(
                "find op type '{}' in program, please check if your program is pruned correctly !".
                format(op.type))
            op_types_set.add(op.type)


def check_saved_vars_try_dump(dump_dir,
                              dump_prog_fn,
                              is_text_dump_program,
                              feed_config,
                              fetch_config,
                              batch_size=1,
                              save_filename=None):
    dump_prog = load_program(
        os.path.join(dump_dir, dump_prog_fn), is_text_dump_program)
    saved_params = [
        v for v in dump_prog.list_vars() if fluid.io.is_persistable(v)
    ]
    logger.info("persistable vars in dump program: {}".format(
        [v.name for v in saved_params]))

    check_not_expected_ops(dump_prog)

    return try_load_model_vars(dump_dir, dump_prog_fn, is_text_dump_program,
                               batch_size, feed_config, fetch_config,
                               save_filename, saved_params)


def parse_program(program, output_dir):
    # persistable vars
    output = {}
    persistable_vars = [
        v for v in program.list_vars() if fluid.io.is_persistable(v)
    ]
    output["persistable_vars"] = [{
        'name': str(v.name),
        'shape': str(v.shape),
        'lod_level': int(v.lod_level),
        'dtype': str(v.dtype),
        'type': str(v.type)
    } for v in persistable_vars]
    with open(os.path.join(output_dir, persistable_vars_out_fn), 'w') as f:
        f.write("persistable vars:\n")
        for var in output["persistable_vars"]:
            f.write(str(var))
            f.write("\n")

    # all vars
    all_vars = [v for v in program.list_vars()]
    output["all_vars"] = [{
        'name': str(v.name),
        'shape': str(v.shape),
        'lod_level': int(v.lod_level),
        'dtype': str(v.dtype)
    } if v.type not in feed_fetch_type_list else {
        'name': str(v.name),
        'type': str(v.type)
    } for v in all_vars]
    with open(os.path.join(output_dir, all_vars_out_fn), 'w') as f:
        f.write("all vars:\n")
        for var in output["all_vars"]:
            f.write(str(var))
            f.write("\n")

    # ops
    ops = program.global_block().ops
    output["ops"] = [{
        'type': op.type,
        'input_arg_names': str(op.input_arg_names),
        'output_arg_names': str(op.output_arg_names)
    } for op in ops]
    with open(os.path.join(output_dir, ops_out_fn), 'w') as f:
        f.write("ops:\n")
        for op in output["ops"]:
            f.write(str(op))
            f.write("\n")
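The utils module can also be driven directly, without going through FleetUtil. A short sketch, not from the commit, assuming a text-format program file at a placeholder path and that the graphviz `dot` binary is on PATH (graphviz() shells out to it):

# Hypothetical direct use of the new utils module; the .pbtxt path is a placeholder.
import paddle.fluid.incubate.fleet.utils.utils as utils

program = utils.load_program("./pruned_main_program.pbtxt", is_text=True)
utils.parse_program(program, "/tmp/")                       # writes vars_all.log, vars_persistable.log, ops.log
utils.save_program(program, "./pruned_main_program.bin")    # re-serialize in binary form (is_text defaults to False)
utils.graphviz(program.global_block(), "/tmp/", "program")  # emits /tmp/program.dot and /tmp/program.pdf via `dot`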
python/paddle/fluid/tests/unittests/test_fleet_utils.py
@@ -13,14 +13,43 @@
# limitations under the License.

from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
import tarfile
import tempfile
import os
import sys
from paddle.dataset.common import download, DATA_HOME
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.utils.fleet_barrier_util import check_all_trainers_ready
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
import paddle.fluid.incubate.fleet.utils.utils as utils


class TestFleetUtils(unittest.TestCase):
    proto_data_url = "https://fleet.bj.bcebos.com/fleet_util_data.tgz"
    proto_data_md5 = "59b7f12fd9dc24b64ae8e4629523a92a"
    module_name = "fleet_util_data"
    pruned_dir = os.path.join("fleet_util_data", "pruned_model")
    train_dir = os.path.join("fleet_util_data", "train_program")

    def download_files(self):
        path = download(self.proto_data_url, self.module_name,
                        self.proto_data_md5)
        print('data is downloaded at ' + path)
        tar = tarfile.open(path)
        unzip_folder = tempfile.mkdtemp()
        tar.extractall(unzip_folder)
        return unzip_folder

    def test_fleet_util_init(self):
        fleet_util_pslib = FleetUtil()
        fleet_util_transpiler = FleetUtil(mode="transpiler")
        self.assertRaises(Exception, FleetUtil, "other")

    def test_fleet_barrier(self):
        role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            ...
@@ -30,6 +59,165 @@ class TestFleetUtils(unittest.TestCase):
        fleet.init(role)
        check_all_trainers_ready("/ready_path/", 0)

    def test_program_type_trans(self):
        data_dir = self.download_files()
        program_dir = os.path.join(data_dir, self.pruned_dir)
        text_program = "pruned_main_program.pbtxt"
        binary_program = "pruned_main_program.bin"
        fleet_util = FleetUtil()
        text_to_binary = fleet_util.program_type_trans(program_dir,
                                                       text_program, True)
        binary_to_text = fleet_util.program_type_trans(program_dir,
                                                       binary_program, False)
        self.assertTrue(
            os.path.exists(os.path.join(program_dir, text_to_binary)))
        self.assertTrue(
            os.path.exists(os.path.join(program_dir, binary_to_text)))

    def test_parse_program_proto(self):
        data_dir = self.download_files()
        parse_program_file_path = os.path.join(
            data_dir,
            os.path.join(self.pruned_dir, "pruned_main_program.pbtxt"))
        is_text_parse_program = True
        parse_output_dir = os.path.join(data_dir, self.pruned_dir)
        fleet_util = FleetUtil()
        fleet_util.parse_program_proto(parse_program_file_path,
                                       is_text_parse_program,
                                       parse_output_dir)
        ops_log = os.path.join(parse_output_dir, "ops.log")
        vars_log = os.path.join(parse_output_dir, "vars_all.log")
        vars_persistable = os.path.join(parse_output_dir,
                                        "vars_persistable.log")
        self.assertTrue(os.path.exists(ops_log))
        self.assertTrue(os.path.exists(vars_log))
        self.assertTrue(os.path.exists(vars_persistable))

    def test_check_vars_and_dump(self):
        data_dir = self.download_files()

        class config:
            pass

        feed_config = config()
        feed_config.feeded_vars_names = ['concat_1.tmp_0', 'concat_2.tmp_0']
        feed_config.feeded_vars_dims = [682, 1199]
        feed_config.feeded_vars_types = [np.float32, np.float32]
        feed_config.feeded_vars_filelist = [
            os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_1")),
            os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_2"))
        ]

        fetch_config = config()
        fetch_config.fetch_vars_names = ['similarity_norm.tmp_0']

        conf = config()
        conf.batch_size = 1
        conf.feed_config = feed_config
        conf.fetch_config = fetch_config
        conf.dump_model_dir = os.path.join(data_dir, self.pruned_dir)
        conf.dump_program_filename = "pruned_main_program.pbtxt"
        conf.is_text_dump_program = True
        conf.save_params_filename = None

        fleet_util = FleetUtil()
        # test saved var's shape
        conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match"
        self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf)
        # test program.proto without feed_op and fetch_op
        conf.dump_program_filename = "pruned_main_program.no_feed_fetch"
        results = fleet_util.check_vars_and_dump(conf)
        self.assertTrue(len(results) == 1)
        np.testing.assert_array_almost_equal(
            results[0], np.array(
                [[3.0590223e-07]], dtype=np.float32))
        # test feed_var's shape
        conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match"
        self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf)
        # test correct case with feed_vars_filelist
        conf.dump_program_filename = "pruned_main_program.pbtxt"
        results = fleet_util.check_vars_and_dump(conf)
        self.assertTrue(len(results) == 1)
        np.testing.assert_array_almost_equal(
            results[0], np.array(
                [[3.0590223e-07]], dtype=np.float32))
        # test correct case without feed_vars_filelist
        conf.feed_config.feeded_vars_filelist = None
        # test feed var with lod_level >= 2
        conf.dump_program_filename = "pruned_main_program.feed_lod2"
        self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf)

        conf.dump_program_filename = "pruned_main_program.pbtxt"
        results = fleet_util.check_vars_and_dump(conf)
        self.assertTrue(len(results) == 1)

    def test_check_two_programs(self):
        data_dir = self.download_files()

        class config:
            pass

        conf = config()
        conf.train_prog_path = os.path.join(
            data_dir, os.path.join(self.train_dir, "join_main_program.pbtxt"))
        conf.is_text_train_program = True

        # test not match
        conf.pruned_prog_path = os.path.join(
            data_dir,
            os.path.join(self.pruned_dir,
                         "pruned_main_program.save_var_shape_not_match"))
        conf.is_text_pruned_program = True
        conf.draw = False
        fleet_util = FleetUtil()
        res = fleet_util.check_two_programs(conf)
        self.assertFalse(res)

        # test match
        conf.pruned_prog_path = os.path.join(
            data_dir,
            os.path.join(self.pruned_dir, "pruned_main_program.pbtxt"))
        if sys.platform == 'win32' or sys.platform == 'sys.platform':
            conf.draw = False
        else:
            conf.draw = True
            conf.draw_out_name = "pruned_check"
        res = fleet_util.check_two_programs(conf)
        self.assertTrue(res)

    def test_draw_program(self):
        if sys.platform == 'win32' or sys.platform == 'sys.platform':
            pass
        else:
            data_dir = self.download_files()
            program_path = os.path.join(
                data_dir,
                os.path.join(self.train_dir, "join_main_program.pbtxt"))
            is_text = True
            program = utils.load_program(program_path, is_text)
            output_dir = os.path.join(data_dir, self.train_dir)
            output_filename_1 = "draw_prog_1"
            output_filename_2 = "draw_prog_2"
            fleet_util = FleetUtil()
            fleet_util.draw_from_program_file(program_path, is_text,
                                              output_dir, output_filename_1)
            fleet_util.draw_from_program(program, output_dir,
                                         output_filename_2)
            self.assertTrue(
                os.path.exists(
                    os.path.join(output_dir, output_filename_1 + ".dot")))
            self.assertTrue(
                os.path.exists(
                    os.path.join(output_dir, output_filename_1 + ".pdf")))
            self.assertTrue(
                os.path.exists(
                    os.path.join(output_dir, output_filename_2 + ".dot")))
            self.assertTrue(
                os.path.exists(
                    os.path.join(output_dir, output_filename_2 + ".pdf")))


if __name__ == '__main__':
    unittest.main()