BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 34bd0045
Authored on Sep 22, 2020 by MrChengmo

refine fleetrun.ps_launch

Parent: f4c750d7
Showing 7 changed files with 580 additions and 245 deletions (+580 −245)
paddle/fluid/operators/distributed/parameter_recv.cc  (+4 −6)
python/paddle/distributed/fleet/base/role_maker.py  (+1 −1)
python/paddle/distributed/fleet/launch.py  (+69 −220)
python/paddle/distributed/fleet/launch_utils.py  (+462 −9)
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py  (+13 −0)
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py  (+7 −7)
python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py  (+24 −2)
paddle/fluid/operators/distributed/parameter_recv.cc
@@ -112,10 +112,6 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto cpu_place = platform::CPUPlace();
  auto &cpu_ctx = *pool.Get(cpu_place);
  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
@@ -125,10 +121,12 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
  if (rpc_ctx.origin_varnames.size() == 1 &&
      rpc_ctx.splited_varnames.size() == 1) {
    auto varname = rpc_ctx.origin_varnames[0];
-    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &ctx = *pool.Get(place);
+    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
+            << platform::is_gpu_place(place);
    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
                                                    scope, varname, varname));
    for (size_t i = 0; i < rets.size(); i++) {
      PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U,
python/paddle/distributed/fleet/base/role_maker.py
@@ -508,7 +508,7 @@ class RoleMakerBase(object):
        and No.1 and No.3 cpu-trainer will work with No.1 gpu-trainerr
        """
        assert self._heter_trainer_endpoints != []
-        return self._heter_trainer_endpoints[(self._current_id + 1) %
+        return self._heter_trainer_endpoints[(self._current_id) %
                                              self._heter_worker_num()]

    def _get_heter_worker_device(self):
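Note: the index change above alters which heter worker each cpu trainer is paired with. Below is a small standalone sketch of that round-robin mapping; the endpoint values and trainer count are invented for illustration and are not taken from this commit.

# Standalone illustration of the pairing rule changed above; the endpoint
# values and the number of cpu trainers are hypothetical.
heter_trainer_endpoints = ["192.168.0.1:6170", "192.168.0.2:6170"]
heter_worker_num = len(heter_trainer_endpoints)

for current_id in range(4):  # pretend there are four cpu trainers
    old_pick = heter_trainer_endpoints[(current_id + 1) % heter_worker_num]
    new_pick = heter_trainer_endpoints[current_id % heter_worker_num]
    print(current_id, old_pick, new_pick)

# With the new indexing, cpu trainer 0 pairs with heter worker 0, trainer 1
# with worker 1, and so on, instead of being shifted by one.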
python/paddle/distributed/fleet/launch.py
@@ -89,14 +89,40 @@ def _parse_args():
    description='''start paddle training using multi-process mode.
see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2-
''')

+    base_group = parser.add_argument_group("Base Parameters")
+    base_group.add_argument(
+        "-d",
+        "--distributed_mode",
+        type=str,
+        choices=["collective", "ps", "ps_heter", "ps_gpu", ""],
+        default="",
+        help="Distributed running mode: collective/ps/ps_gpu/ps_heter")
+    base_group.add_argument(
+        "--log_dir",
+        type=str,
+        default="log",
+        help="The path for each process's log.If it's not set, the log will printed to default pipe."
+    )
+    base_group.add_argument(
+        "training_script",
+        type=str,
+        help="The full path to the single GPU training "
+        "program/script to be launched in parallel, "
+        "followed by all the arguments for the "
+        "training script")

-    # Optional arguments for the launch helper
-    parser.add_argument(
+    # for collective
+    collective_group = parser.add_argument_group("Collective Parameters")
+    collective_group.add_argument(
        "--ips",
        type=str,
        default="127.0.0.1",
        help="Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..")
-    parser.add_argument(
+    collective_group.add_argument(
        "--gpus",
        type=str,
        default=None,
@@ -104,31 +130,30 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"each process is bound to a single GPU. And if it's not set, this module will use all the gpu cards for training."
)
parser
.
add_argument
(
ps_group
=
parser
.
add_argument_group
(
"Parameter-Server Parameters"
)
# for parameter server
ps_group
.
add_argument
(
"--servers"
,
type
=
str
,
default
=
""
,
help
=
"User defined servers ip:port"
)
p
arser
.
add_argument
(
p
s_group
.
add_argument
(
"--workers"
,
type
=
str
,
default
=
""
,
help
=
"User defined workers ip:port"
)
parser
.
add_argument
(
"--worker_num"
,
type
=
int
,
help
=
"number of workers"
)
ps_group
.
add_argument
(
"--heter_workers"
,
type
=
str
,
default
=
""
,
help
=
"User defined heter workers ip:port"
)
parser
.
add_argument
(
"--server_num"
,
type
=
int
,
help
=
"number of servers"
)
ps_group
.
add_argument
(
"--worker_num"
,
type
=
int
,
help
=
"number of workers"
)
ps_group
.
add_argument
(
"--server_num"
,
type
=
int
,
help
=
"number of servers"
)
ps_group
.
add_argument
(
"--heter_worker_num"
,
type
=
int
,
help
=
"number of heter_workers"
)
p
arser
.
add_argument
(
"--
log_dir
"
,
p
s_group
.
add_argument
(
"--
heter_worker_device
"
,
type
=
str
,
default
=
"log"
,
help
=
"The path for each process's log.If it's not set, the log will printed to default pipe."
)
# positional
parser
.
add_argument
(
"training_script"
,
type
=
str
,
help
=
"The full path to the single GPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script"
)
default
=
"gpu"
,
choices
=
[
"gpu"
,
"xpu"
],
help
=
"heter worker device"
)
# rest from the training program
parser
.
add_argument
(
'training_script_args'
,
nargs
=
REMAINDER
)
return
parser
.
parse_args
()
...
...
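For reference, the following is a minimal standalone sketch (plain argparse, not Paddle code) of how a parameter-server command line would be parsed by the argument groups added above; the sample script name and flag values are made up.

# Standalone sketch mirroring the new argument groups; values are hypothetical.
import argparse
from argparse import REMAINDER

parser = argparse.ArgumentParser()
base_group = parser.add_argument_group("Base Parameters")
base_group.add_argument(
    "-d", "--distributed_mode", type=str,
    choices=["collective", "ps", "ps_heter", "ps_gpu", ""], default="")
base_group.add_argument("--log_dir", type=str, default="log")
base_group.add_argument("training_script", type=str)

ps_group = parser.add_argument_group("Parameter-Server Parameters")
ps_group.add_argument("--servers", type=str, default="")
ps_group.add_argument("--workers", type=str, default="")
ps_group.add_argument("--server_num", type=int)
ps_group.add_argument("--worker_num", type=int)

parser.add_argument("training_script_args", nargs=REMAINDER)

# Roughly what "fleetrun -d ps --server_num 2 --worker_num 2 train.py --lr 0.1"
# would produce (train.py and --lr are placeholders):
args = parser.parse_args(
    ["-d", "ps", "--server_num", "2", "--worker_num", "2",
     "train.py", "--lr", "0.1"])
print(args.distributed_mode, args.server_num, args.worker_num)
print(args.training_script, args.training_script_args)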
@@ -246,209 +271,32 @@ def launch_collective(args):
def launch_ps(args):
-    ports = None
-    start_port = 6170
-    if args.server_num:
-        server_num = args.server_num
-        ports = get_ports(server_num, 0)
-        server_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
-    else:
-        assert args.servers != "", "The setting of CPU mode must be either server_num or servers."
-        server_endpoints = args.servers
-    server_endpoints_ips = [
-        x.strip().split(":")[0] for x in server_endpoints.split(",")
-    ]
-    server_endpoints_port = [
-        x.strip().split(":")[1] for x in server_endpoints.split(",")
-    ]
-    server_num = len(server_endpoints_ips)
-
-    cloud_flag = cloud_utils.use_paddlecloud()
-
-    if args.worker_num:
-        worker_num = args.worker_num
-        ports = get_ports(worker_num, server_num)
-        worker_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
-    else:
-        assert args.workers != "", "The setting of CPU mode must be either worker_num or workers."
-        worker_endpoints = args.workers
-    worker_endpoints_ips = [
-        x.strip().split(":")[0] for x in worker_endpoints.split(",")
-    ]
-    worker_num = len(worker_endpoints_ips)
-    node_ips = list(set(server_endpoints_ips + worker_endpoints_ips))
-    worker_endpoints_len = [
-        len(x.strip().split(":")) for x in worker_endpoints.split(",")
-    ]
-    if 1 in worker_endpoints_len:
-        # if no port value in worker_endpoints, will set default port values.
-        worker_endpoints_port = range(start_port + server_num,
-                                      start_port + server_num + worker_num, 1)
-    else:
-        worker_endpoints_port = [
-            x.strip().split(":")[1] for x in worker_endpoints.split(",")
-        ]
-
-    # for ps-cpu on paddlecloud
-    direct_start_mode = ["ps", ""]
-    if cloud_flag and (args.distributed_mode in direct_start_mode):
-        direct_start(args)
-        return
-    elif cloud_flag and args.distributed_mode == "ps_heter":
-        cloud_ps_heter_env_set(args)
-        args.trainers = os.getenv("PADDLE_TRAINER_ENDPOINTS")
-        args.workers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
-        args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST")
-
-    # local train
-    if len(set(node_ips)) == 1:
-        current_node_ip = node_ips[0]
-    else:
-        _, current_node_ip = get_host_name_ip()
-
-    assert current_node_ip in node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
-        % (current_node_ip, node_ips)
-    node_rank = node_ips.index(current_node_ip)
-    logger.debug(
-        "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}, server_ports:{}".
-        format(node_ips, current_node_ip, node_rank, server_endpoints_port))
-
-    cluster = Cluster(hdfs=None)
-    server_rank = 0
-    worker_rank = 0
-    for node_rank, ip in enumerate(node_ips):
-        pod = Pod()
-        pod.rank = node_rank
-        pod.addr = ip
-        for i in range(len(server_endpoints_ips)):
-            if ip == server_endpoints_ips[i]:
-                server = Trainer()
-                server.endpoint = "%s:%s" % (ip, server_endpoints_port[i])
-                server.rank = server_rank
-                server_rank += 1
-                pod.servers.append(server)
-        for j in range(len(worker_endpoints_ips)):
-            if ip == worker_endpoints_ips[j]:
-                worker = Trainer()
-                worker.endpoint = "%s:%s" % (ip, worker_endpoints_port[i])
-                worker.rank = worker_rank
-                worker_rank += 1
-                pod.workers.append(worker)
-
-        cluster.pods.append(pod)
-
-    pod_rank = node_ips.index(current_node_ip)
-    pod = cluster.pods[pod_rank]
-
-    default_env = os.environ.copy()
-    current_env = copy.copy(default_env)
-
-    gloo_rendezvous_dir = tempfile.mkdtemp()
-    # add gloo env
-    current_env["PADDLE_WITH_GLOO"] = "1"
-    current_env["PADDLE_GLOO_RENDEZVOUS"] = "2"
-    current_env["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir
-
-    current_env.pop("http_proxy", None)
-    current_env.pop("https_proxy", None)
-    procs = []
-    cmds = []
-    log_fns = []
-    for idx, cur_server in enumerate(pod.servers):
-        proc_env = {
-            "PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints,
-            "PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
-            "PADDLE_PORT": cur_server.endpoint.split(":")[1],
-            "TRAINING_ROLE": "PSERVER",
-            "PADDLE_TRAINERS_NUM": str(worker_num),
-            "POD_IP": cur_server.endpoint.split(":")[0],
-            "PADDLE_WITH_GLOO": "1"
-        }
-        current_env.update(proc_env)
-
-        cmd = [sys.executable, "-u", args.training_script] + args.training_script_args
-        cmds.append(cmd)
-
-        if idx == 0:
-            logger.info(
-                "Local server start {} processes. First process distributed "
-                "environment info (Only For Debug): {}".format(
-                    len(pod.servers),
-                    pretty_print_envs(proc_env, ("Distributed Envs", "Value"))))
-
-        if args.log_dir is not None:
-            os.system("mkdir -p {}".format(args.log_dir))
-            fn = open("%s/serverlog.%d" % (args.log_dir, idx), "w")
-            log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
-        else:
-            proc = subprocess.Popen(cmd, env=current_env)
-
-        tp = TrainerProc()
-        tp.proc = proc
-        tp.rank = cur_server.rank
-        tp.local_rank = idx
-        tp.log_fn = fn
-        tp.log_offset = fn.tell() if fn else None
-        tp.cmd = cmd
-
-        procs.append(tp)
-
-    for idx, cur_worker in enumerate(pod.workers):
-        proc_env = {
-            "PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints,
-            "PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
-            "PADDLE_TRAINERS_NUM": str(worker_num),
-            "TRAINING_ROLE": "TRAINER",
-            "PADDLE_TRAINER_ID": str(cur_worker.rank),
-            "PADDLE_WITH_GLOO": "1"
-        }
-        current_env.update(proc_env)
-
-        cmd = [sys.executable, "-u", args.training_script] + args.training_script_args
-        cmds.append(cmd)
-
-        if idx == 0:
-            logger.info(
-                "Local worker start {} processes. First process distributed "
-                "environment info (Only For Debug): {}".format(
-                    len(pod.workers),
-                    pretty_print_envs(proc_env, ("Distributed Envs", "Value"))))
-
-        if args.log_dir is not None:
-            os.system("mkdir -p {}".format(args.log_dir))
-            fn = open("%s/workerlog.%d" % (args.log_dir, idx), "w")
-            log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
-        else:
-            proc = subprocess.Popen(cmd, env=current_env)
-
-        tp = TrainerProc()
-        tp.proc = proc
-        tp.rank = cur_worker.rank
-        tp.local_rank = idx
-        tp.log_fn = fn
-        tp.log_offset = fn.tell() if fn else None
-        tp.cmd = cmd
-
-        procs.append(tp)
-
-    logger.info(
-        "Please check servers and workers logs in {}/workerlog.* and {}/serverlog.*".
-        format(args.log_dir, args.log_dir))
-    # only wait worker to finish here
-    for i, proc in enumerate(procs):
-        if i < len(pod.servers):
-            continue
-        procs[i].proc.wait()
-        if len(log_fns) > 0:
-            log_fns[i].close()
-
-    print("all workers exit, going to finish parameter server", file=sys.stderr)
-    for i in range(len(pod.servers)):
-        if len(log_fns) > 0:
-            log_fns[i].close()
-        procs[i].proc.terminate()
-    print("all parameter server are killed", file=sys.stderr)
-
-    if os.path.exists(gloo_rendezvous_dir):
-        shutil.rmtree(gloo_rendezvous_dir)
+    ps_launcher = ParameterServerLauncher(args)
+    ps_launcher.start_ps(args)
+    return


def launch():
    args = _parse_args()
    logger = get_logger()
    _print_arguments(args)
-    ps_args = ['--worker_num', '--server_num', '--servers', '--workers']
+    ps_args = [
+        '--worker_num', '--server_num', '--heter_worker_num', '--servers',
+        '--workers', '--heter_worrkers', 'heter_worker_device'
+    ]
    collective_args = ['--ips', '--gpus']
    has_ps_args = [
        ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
@@ -462,9 +310,10 @@ def launch():
    else:
        cuda_device_num = 0

-    if len(has_ps_args) > 0 or cuda_device_num == 0:
+    ps_mode = ['ps', 'ps_gpu', 'ps_heter']
+    if len(has_ps_args) > 0 or args.distributed_mode in ps_mode:
        logger.info(
-            "Run parameter-sever cpu mode. pserver arguments:{}, cuda count:{}".
+            "Run parameter-sever mode. pserver arguments:{}, cuda count:{}".
            format(has_ps_args, cuda_device_num))
        launch_ps(args)
    elif len(has_collective_args) > 0:
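The dispatch change in this hunk means the parameter-server path is now selected from the explicit --distributed_mode value (or the presence of ps arguments) rather than from the CUDA device count. A simplified standalone sketch of the new decision follows; the function and return values are placeholders, not the Paddle implementation.

# Simplified stand-in for the dispatch condition changed above.
def choose_launcher(has_ps_args, distributed_mode, has_collective_args):
    ps_mode = ['ps', 'ps_gpu', 'ps_heter']
    if len(has_ps_args) > 0 or distributed_mode in ps_mode:
        return "launch_ps"
    elif len(has_collective_args) > 0:
        return "launch_collective"
    return "undecided"  # the fallback branch is outside this hunk

# A machine without CUDA devices and without ps flags no longer falls into
# ps mode by default:
print(choose_launcher([], "", ["--ips"]))          # launch_collective
print(choose_launcher(["--server_num"], "", []))   # launch_ps
print(choose_launcher([], "ps_gpu", []))           # launch_ps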
python/paddle/distributed/fleet/launch_utils.py
This diff is collapsed.
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
@@ -13,6 +13,7 @@
from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
from ..base.private_helper_function import wait_server_ready
from paddle.fluid import core
import subprocess
import re
@@ -74,6 +75,8 @@ class ParameterServerOptimizer(MetaOptimizerBase):
            _startup = worker.delet_extra_optimizes_pass(_startup,
                                                         compiled_config)
+            compiled_config.set_origin_ps_main_program(_main)
+            compiled_config.set_origin_ps_startup_program(_startup)

            # for heter program
            if self.role_maker._is_heter_parameter_server_mode:
                from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker
@@ -91,6 +94,16 @@ class ParameterServerOptimizer(MetaOptimizerBase):
            else:
                _main = worker.append_send_ops_pass(_main, compiled_config)
                _startup = _startup

+            compiled_config.set_origin_ps_main_program(_main)
+            compiled_config.set_origin_ps_startup_program(_startup)
+
+            # for trainer wait server ready
+            wait_server_ready(self.role_maker._get_pserver_endpoints())
+
+            # for ps-heter mode, wait heter worker ready
+            if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
+            ):
+                wait_server_ready(self.role_maker._get_heter_worker_endpoints())
+
        return _main, _startup
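The wait_server_ready calls added above make the trainer block until every pserver endpoint (and, in ps-heter mode, every heter worker endpoint) accepts a connection before the programs are returned. Below is a rough standalone sketch of that kind of readiness wait; it is simplified and not the real helper, which the diff imports from paddle.distributed.fleet.base.private_helper_function.

# Rough sketch of an endpoint readiness wait; not the Paddle implementation.
import socket
import time


def wait_endpoints_ready(endpoints, connect_timeout=1.0, retry_interval=3.0):
    pending = list(endpoints)
    while pending:
        still_pending = []
        for ep in pending:
            ip, port = ep.split(":")
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(connect_timeout)
                # connect_ex returns 0 once the endpoint accepts connections
                if sock.connect_ex((ip, int(port))) != 0:
                    still_pending.append(ep)
        if still_pending:
            time.sleep(retry_interval)
        pending = still_pending


# e.g. wait_endpoints_ready(["127.0.0.1:6170", "127.0.0.1:6171"])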
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
@@ -458,13 +458,13 @@ class ParameterServerRuntime(RuntimeBase):
    def _save_distributed_persistables(self, executor, dirname, main_program):
        dense_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=1)
+            recv_type=1, use_origin_program=True)
        sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=2)
+            recv_type=2, use_origin_program=True)
        distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=3)
+            recv_type=3, use_origin_program=True)

        recv_dense_varnames = self._save_dense_params(executor, dirname,
                                                      dense_ctx, main_program)
@@ -516,7 +516,7 @@ class ParameterServerRuntime(RuntimeBase):
            )

        if main_program is None:
-            main_program = fluid.default_main_program()
+            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
@@ -133,6 +133,8 @@ class CompileTimeStrategy(object):
        self.origin_main_program = main_program
        self.origin_startup_program = startup_program
+        self.origin_ps_main_program = main_program
+        self.origin_ps_startup_program = startup_program

        self.strategy = strategy
        self.role_maker = role_maker
@@ -153,6 +155,11 @@ class CompileTimeStrategy(object):
        self._build_var_distributed()

+        # for heter-ps save variables
+        self.origin_merged_variables_pairs = list(self.merged_variables_pairs)
+        self.origin_merged_dense_pairs = list(self.merged_dense_pairs)
+        self.origin_merged_sparse_pairs = list(self.merged_sparse_pairs)
+
    def get_distributed_mode(self):
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode
@@ -214,6 +221,18 @@ class CompileTimeStrategy(object):
    def get_origin_startup_program(self):
        return self.origin_startup_program

+    def set_origin_ps_main_program(self, program):
+        self.origin_ps_main_program = program
+
+    def set_origin_ps_startup_program(self, program):
+        self.origin_ps_startup_program = program
+
+    def get_origin_ps_main_program(self):
+        return self.origin_ps_main_program
+
+    def get_origin_ps_startup_program(self):
+        return self.origin_ps_startup_program
+
    def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
        if not endpoint:
            endpoint = self.get_ps_endpoint()
@@ -378,7 +397,9 @@ class CompileTimeStrategy(object):
            send_ctx[name] = ctx

        return send_ctx

-    def get_communicator_recv_context(self, recv_type=1):
+    def get_communicator_recv_context(self,
+                                      recv_type=1,
+                                      use_origin_program=False):
        # recv_type
        # 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL
        distibuted_varnames = get_sparse_tablenames(self.origin_main_program,
@@ -392,7 +413,8 @@ class CompileTimeStrategy(object):
        sparse_recv_ctx = {}
        distributed_recv_ctx = {}

-        for merged in self.merged_variables_pairs:
+        variables_pairs = self.merged_variables_pairs if not use_origin_program else self.origin_merged_variables_pairs
+        for merged in variables_pairs:
            params = merged[0]
            if params.merged_var.name in sparse_varnames:
                continue