Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ab18644c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ab18644c
编写于
11月 28, 2022
作者:
W
wangguanqun
提交者:
GitHub
11月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove fluid (#47959)
* remove fluid * update public * core * public * public1 * ci
上级
fd689106
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
1126 addition
and
31 deletion
+1126
-31
python/paddle/distributed/communicator.py
python/paddle/distributed/communicator.py
+269
-0
python/paddle/distributed/passes/ps_server_pass.py
python/paddle/distributed/passes/ps_server_pass.py
+15
-7
python/paddle/distributed/passes/ps_trainer_pass.py
python/paddle/distributed/passes/ps_trainer_pass.py
+3
-3
python/paddle/distributed/ps/coordinator.py
python/paddle/distributed/ps/coordinator.py
+1
-1
python/paddle/distributed/ps/the_one_ps.py
python/paddle/distributed/ps/the_one_ps.py
+11
-14
python/paddle/distributed/ps/utils/collective_transpiler.py
python/paddle/distributed/ps/utils/collective_transpiler.py
+819
-0
python/paddle/distributed/ps/utils/ps_program_builder.py
python/paddle/distributed/ps/utils/ps_program_builder.py
+6
-4
python/paddle/distributed/ps/utils/public.py
python/paddle/distributed/ps/utils/public.py
+2
-2
未找到文件。
python/paddle/distributed/communicator.py
0 → 100755
浏览文件 @
ab18644c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright(c) 2019 PaddlePaddle Authors.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http: // www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
"""
import
paddle
from
paddle.framework
import
core
from
paddle.distributed.ps.utils.public
import
DistributedMode
__all__
=
[
'Communicator'
,
'FLCommunicator'
,
'LargeScaleKV'
]
class
Communicator
:
def
__init__
(
self
,
mode
,
kwargs
=
None
,
envs
=
None
):
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
Args:
program(Program): the trainers program after transpile of distribute_transpiler.
It's used by communicator to extract the information to do communication.
Returns:
None
Examples:
.. code-block:: python
import paddle
prog = paddle.static.Program()
comm = paddle.distributed.communicator.Communicator(prog)
comm.start()
comm.stop()
"""
# set all recv op to not_run mode
if
kwargs
is
None
:
if
envs
is
None
:
envs
=
{}
else
:
if
mode
==
DistributedMode
.
SYNC
:
envs
[
"pserver_endpoints"
]
=
','
.
join
(
kwargs
[
"pserver_endpoints"
]
)
envs
[
"trainers"
]
=
str
(
kwargs
[
"trainers"
])
envs
[
"trainer_id"
]
=
str
(
kwargs
[
"trainer_id"
])
envs
[
"need_global_step"
]
=
str
(
kwargs
[
"need_global_step"
])
envs
[
"barrier_table_id"
]
=
str
(
kwargs
[
"barrier_table_id"
])
mode_str
=
None
if
mode
==
DistributedMode
.
SYNC
:
mode_str
=
"SYNC"
elif
mode
==
DistributedMode
.
ASYNC
:
mode_str
=
"ASYNC"
elif
mode
==
DistributedMode
.
HALF_ASYNC
:
mode_str
=
"HALF_ASYNC"
elif
mode
==
DistributedMode
.
GEO
:
mode_str
=
"GEO"
self
.
mode
=
mode_str
self
.
envs
=
envs
self
.
communicator_
=
None
self
.
send_ctx_
=
None
self
.
recv_ctx_
=
None
def
init_with_ctx
(
self
,
send_ctx
,
recv_ctx
,
proto_txt
,
unit64_hosts
,
scope
=
None
):
if
scope
is
None
:
scope
=
paddle
.
static
.
global_scope
()
self
.
communicator_
=
core
.
DistCommunicator
(
self
.
mode
,
proto_txt
,
unit64_hosts
,
send_ctx
,
recv_ctx
,
scope
,
self
.
envs
,
)
self
.
send_ctx_
=
send_ctx
self
.
recv_ctx_
=
recv_ctx
def
create_client_to_client_connection
(
self
,
pserver_timeout_ms
=
500000
,
pserver_connect_timeout_ms
=
10000
,
max_retry
=
3
,
):
self
.
communicator_
.
create_client_to_client_connection
(
pserver_timeout_ms
,
pserver_connect_timeout_ms
,
max_retry
)
def
get_client_info
(
self
):
return
self
.
communicator_
.
get_client_info
()
def
set_clients
(
self
,
host_list
):
self
.
communicator_
.
set_clients
(
host_list
)
def
start
(
self
):
"""
Start communicator. Should call before training process.
Returns:
None
Examples:
.. code-block:: python
import paddle
prog = paddle.static.Program()
comm = paddle.distributed.communicator.Communicator(prog)
comm.start()
comm.stop()
"""
if
self
.
communicator_
is
None
:
print
(
'you must call init_with_ctx first to init comm before start'
)
return
self
.
communicator_
.
start
()
def
stop
(
self
):
"""
Stop communicator. Should call after training process.
Returns:
None
Examples:
.. code-block:: python
import paddle
prog = paddle.static.Program()
comm = paddle.distributed.communicator.Communicator(prog)
comm.start()
comm.stop()
"""
if
self
.
communicator_
is
None
:
print
(
'you must call init_with_ctx first to init comm before stop'
)
return
self
.
communicator_
.
stop
()
def
is_running
(
self
):
"""
Get communicator is running or stop.
Returns:
bool
Examples:
.. code-block:: python
import paddle
prog = paddle.static.Program()
comm = paddle.distributed.communicator.Communicator(prog)
comm.is_running()
"""
if
self
.
communicator_
is
None
:
print
(
'you must call init_with_ctx first to init comm before stop'
)
return
self
.
communicator_
.
is_running
()
def
recv
(
self
):
self
.
communicator_
.
recv
()
def
init_params
(
self
,
context
):
self
.
communicator_
.
init_params
(
context
)
def
pull_dense
(
self
,
context
):
self
.
communicator_
.
pull_dense
(
context
)
def
push_sparse_param
(
self
,
var_name
,
table_id
=-
1
,
scope
=
None
):
if
scope
is
None
:
scope
=
paddle
.
static
.
global_scope
()
if
not
self
.
is_running
():
raise
ValueError
(
"Communicator should init first. Using fleet.init_worker() before push_sparse_param()"
)
assert
isinstance
(
var_name
,
str
)
assert
isinstance
(
table_id
,
int
)
if
table_id
==
-
1
:
table_id
=
self
.
send_ctx_
[
var_name
].
table_id
()
self
.
communicator_
.
push_sparse_param
(
var_name
,
table_id
,
scope
)
class
FLCommunicator
(
Communicator
):
# only for coordinator
def
__init__
(
self
,
ps_hosts
,
kwargs
=
None
):
mode
=
None
super
().
__init__
(
mode
,
kwargs
)
send_ctx
=
{}
dense_map
=
{}
prototxt
=
""
self
.
mode
=
"WITH_COORDINATOR"
self
.
init_with_ctx
(
send_ctx
,
dense_map
,
prototxt
,
ps_hosts
)
def
start_coordinator
(
self
,
self_endpoint
,
trainer_endpoints
):
if
self
.
communicator_
is
not
None
:
self
.
communicator_
.
start_coordinator
(
self_endpoint
,
trainer_endpoints
)
return
def
save_fl_strategy
(
self
,
mp
):
if
self
.
communicator_
is
not
None
:
self
.
communicator_
.
save_fl_strategy
(
mp
)
else
:
raise
ValueError
(
"self.communicator_ is null"
)
return
def
query_fl_clients_info
(
self
):
info_mp
=
{}
if
self
.
communicator_
is
not
None
:
info_mp
=
self
.
communicator_
.
query_fl_clients_info
()
return
info_mp
class
LargeScaleKV
:
def
__init__
(
self
):
self
.
scale_kv
=
core
.
LargeScaleKV
()
def
save
(
self
,
varname
,
dirname
):
self
.
scale_kv
.
save
(
varname
,
dirname
)
def
load
(
self
,
varname
,
dirname
):
self
.
scale_kv
.
load
(
varname
,
dirname
)
def
size
(
self
,
varname
):
return
self
.
scale_kv
.
size
(
varname
)
class
HeterClient
:
def
__init__
(
self
,
endpoint
,
previous_endpoint
,
trainer_id
):
self
.
heter_client_
=
core
.
HeterClient
(
endpoint
,
previous_endpoint
,
trainer_id
)
def
stop
(
self
):
self
.
heter_client_
.
stop
()
python/paddle/distributed/passes/ps_server_pass.py
浏览文件 @
ab18644c
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
logging
import
logging
import
paddle
.fluid
as
fluid
import
paddle
from
..ps.utils.public
import
(
from
..ps.utils.public
import
(
get_optimize_ops
,
get_optimize_ops
,
get_ps_endpoint
,
get_ps_endpoint
,
...
@@ -76,12 +76,14 @@ class AddLrDecayTablePass(PassBase):
...
@@ -76,12 +76,14 @@ class AddLrDecayTablePass(PassBase):
'ExponentialDecay'
,
'ExponentialDecay'
,
]
]
decay_main_program
=
fluid
.
framework
.
Program
()
decay_main_program
=
paddle
.
static
.
Program
()
decay_startup_program
=
fluid
.
framework
.
Program
()
decay_startup_program
=
paddle
.
static
.
Program
()
lr_name
=
""
lr_name
=
""
if
isinstance
(
lr_sheduler
,
ExponentialDecay
):
if
isinstance
(
lr_sheduler
,
ExponentialDecay
):
with
fluid
.
program_guard
(
decay_main_program
,
decay_startup_program
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
exponential_decay
(
lr
=
exponential_decay
(
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
)
)
...
@@ -94,7 +96,9 @@ class AddLrDecayTablePass(PassBase):
...
@@ -94,7 +96,9 @@ class AddLrDecayTablePass(PassBase):
%
lr_decay_steps
%
lr_decay_steps
)
)
elif
isinstance
(
lr_sheduler
,
NoamDecay
):
elif
isinstance
(
lr_sheduler
,
NoamDecay
):
with
fluid
.
program_guard
(
decay_main_program
,
decay_startup_program
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
noam_decay
(
lr
=
noam_decay
(
lr_sheduler
.
d_model
,
lr_sheduler
.
warmup_steps
,
1.0
lr_sheduler
.
d_model
,
lr_sheduler
.
warmup_steps
,
1.0
)
)
...
@@ -104,7 +108,9 @@ class AddLrDecayTablePass(PassBase):
...
@@ -104,7 +108,9 @@ class AddLrDecayTablePass(PassBase):
%
lr_sheduler
.
warmup_steps
%
lr_sheduler
.
warmup_steps
)
)
elif
isinstance
(
lr_sheduler
,
NaturalExpDecay
):
elif
isinstance
(
lr_sheduler
,
NaturalExpDecay
):
with
fluid
.
program_guard
(
decay_main_program
,
decay_startup_program
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
natural_exp_decay
(
lr
=
natural_exp_decay
(
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
)
)
...
@@ -117,7 +123,9 @@ class AddLrDecayTablePass(PassBase):
...
@@ -117,7 +123,9 @@ class AddLrDecayTablePass(PassBase):
%
lr_decay_steps
%
lr_decay_steps
)
)
elif
isinstance
(
lr_sheduler
,
InverseTimeDecay
):
elif
isinstance
(
lr_sheduler
,
InverseTimeDecay
):
with
fluid
.
program_guard
(
decay_main_program
,
decay_startup_program
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
inverse_time_decay
(
lr
=
inverse_time_decay
(
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
)
)
...
...
python/paddle/distributed/passes/ps_trainer_pass.py
浏览文件 @
ab18644c
...
@@ -17,10 +17,10 @@ import paddle
...
@@ -17,10 +17,10 @@ import paddle
from
..ps.utils.public
import
*
# noqa: F403
from
..ps.utils.public
import
*
# noqa: F403
from
paddle.framework
import
core
from
paddle.framework
import
core
from
paddle.distributed.passes.pass_base
import
PassBase
,
register_pass
from
paddle.distributed.passes.pass_base
import
PassBase
,
register_pass
from
paddle.fluid.transpiler.details.program_utils
import
delete_ops
from
..ps.utils.collective_transpiler
import
SingleProcessMultiThread
from
paddle.fluid.transpiler.collective
import
SingleProcessMultiThread
from
_collections
import
defaultdict
from
_collections
import
defaultdict
from
paddle.fluid.framework
import
Program
,
Parameter
from
paddle.static
import
Program
from
paddle.fluid.framework
import
Parameter
@
register_pass
(
"append_send_ops_pass"
)
@
register_pass
(
"append_send_ops_pass"
)
...
...
python/paddle/distributed/ps/coordinator.py
浏览文件 @
ab18644c
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
paddle
import
paddle
from
paddle.
flui
d.communicator
import
FLCommunicator
from
paddle.
distribute
d.communicator
import
FLCommunicator
from
paddle.distributed.fleet.proto
import
the_one_ps_pb2
from
paddle.distributed.fleet.proto
import
the_one_ps_pb2
from
google.protobuf
import
text_format
from
google.protobuf
import
text_format
from
paddle.distributed.ps.utils.public
import
is_distributed_env
from
paddle.distributed.ps.utils.public
import
is_distributed_env
...
...
python/paddle/distributed/ps/the_one_ps.py
浏览文件 @
ab18644c
...
@@ -15,20 +15,17 @@
...
@@ -15,20 +15,17 @@
import
warnings
import
warnings
import
os
import
os
import
paddle
.fluid
as
fluid
import
paddle
from
paddle.distributed
import
fleet
from
paddle.distributed
import
fleet
from
paddle.f
luid
import
core
from
paddle.f
ramework
import
core
from
paddle.distributed.ps.utils.public
import
*
# noqa: F403
from
paddle.distributed.ps.utils.public
import
*
# noqa: F403
from
paddle.fluid.framework
import
Program
from
paddle.static
import
Program
,
CompiledProgram
,
Executor
,
ParallelExecutor
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.parallel_executor
import
ParallelExecutor
from
paddle.distributed.fleet.runtime.runtime_base
import
RuntimeBase
from
paddle.distributed.fleet.runtime.runtime_base
import
RuntimeBase
from
paddle.distributed.fleet.base.private_helper_function
import
(
from
paddle.distributed.fleet.base.private_helper_function
import
(
wait_server_ready
,
wait_server_ready
,
)
)
from
paddle.distributed.fleet.proto
import
the_one_ps_pb2
from
paddle.distributed.fleet.proto
import
the_one_ps_pb2
from
paddle.
flui
d.communicator
import
Communicator
,
HeterClient
from
paddle.
distribute
d.communicator
import
Communicator
,
HeterClient
from
google.protobuf
import
text_format
from
google.protobuf
import
text_format
from
paddle.distributed.ps.coordinator
import
Coordinator
from
paddle.distributed.ps.coordinator
import
Coordinator
...
@@ -1035,7 +1032,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1035,7 +1032,7 @@ class TheOnePSRuntime(RuntimeBase):
super
().
__init__
()
super
().
__init__
()
self
.
_communicator
=
None
self
.
_communicator
=
None
self
.
_server
=
None
self
.
_server
=
None
self
.
_worker
=
fluid
.
core
.
DistFleetWrapper
()
self
.
_worker
=
core
.
DistFleetWrapper
()
self
.
_coordinator
=
None
self
.
_coordinator
=
None
self
.
_server_sub_program
=
[]
self
.
_server_sub_program
=
[]
self
.
_heter_client
=
None
self
.
_heter_client
=
None
...
@@ -1092,7 +1089,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1092,7 +1089,7 @@ class TheOnePSRuntime(RuntimeBase):
self
.
string_hosts
=
[]
self
.
string_hosts
=
[]
for
idx
,
ep
in
enumerate
(
self
.
endpoints
):
for
idx
,
ep
in
enumerate
(
self
.
endpoints
):
host
,
port
=
ep
.
split
(
":"
)
host
,
port
=
ep
.
split
(
":"
)
pshost
=
fluid
.
core
.
PSHost
(
host
,
int
(
port
),
idx
)
pshost
=
core
.
PSHost
(
host
,
int
(
port
),
idx
)
self
.
string_hosts
.
append
(
pshost
.
serialize_to_string
())
self
.
string_hosts
.
append
(
pshost
.
serialize_to_string
())
self
.
with_coordinator
=
self
.
role_maker
.
_with_coordinator
self
.
with_coordinator
=
self
.
role_maker
.
_with_coordinator
...
@@ -1102,7 +1099,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1102,7 +1099,7 @@ class TheOnePSRuntime(RuntimeBase):
coordinator_endpoints
=
self
.
role_maker
.
_get_coordinator_endpoints
()
coordinator_endpoints
=
self
.
role_maker
.
_get_coordinator_endpoints
()
for
idx
,
ep
in
enumerate
(
coordinator_endpoints
):
for
idx
,
ep
in
enumerate
(
coordinator_endpoints
):
ip
,
port
=
ep
.
split
(
":"
)
ip
,
port
=
ep
.
split
(
":"
)
pshost
=
fluid
.
core
.
PSHost
(
ip
,
int
(
port
),
idx
)
pshost
=
core
.
PSHost
(
ip
,
int
(
port
),
idx
)
self
.
coordinator_hosts
.
append
(
pshost
.
serialize_to_string
())
self
.
coordinator_hosts
.
append
(
pshost
.
serialize_to_string
())
self
.
ps_desc_builder
=
PsDescBuilder
(
self
.
context
)
self
.
ps_desc_builder
=
PsDescBuilder
(
self
.
context
)
...
@@ -1173,7 +1170,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1173,7 +1170,7 @@ class TheOnePSRuntime(RuntimeBase):
gpus_env
=
os
.
getenv
(
"FLAGS_selected_gpus"
)
gpus_env
=
os
.
getenv
(
"FLAGS_selected_gpus"
)
gpus_env
=
[
int
(
s
)
for
s
in
gpus_env
.
split
(
","
)]
gpus_env
=
[
int
(
s
)
for
s
in
gpus_env
.
split
(
","
)]
main_program
.
_fleet_opt
[
"worker_places"
]
=
gpus_env
main_program
.
_fleet_opt
[
"worker_places"
]
=
gpus_env
PSGPU
=
fluid
.
core
.
PSGPU
()
PSGPU
=
core
.
PSGPU
()
PSGPU
.
init_gpu_ps
(
gpus_env
)
PSGPU
.
init_gpu_ps
(
gpus_env
)
def
sync_strategy_envs
():
def
sync_strategy_envs
():
...
@@ -1241,7 +1238,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1241,7 +1238,7 @@ class TheOnePSRuntime(RuntimeBase):
dense_map
,
dense_map
,
worker_desc
,
worker_desc
,
self
.
string_hosts
,
self
.
string_hosts
,
fluid
.
global_scope
(),
paddle
.
static
.
global_scope
(),
)
)
fleet
.
util
.
barrier
()
fleet
.
util
.
barrier
()
...
@@ -1273,7 +1270,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1273,7 +1270,7 @@ class TheOnePSRuntime(RuntimeBase):
raise
ValueError
(
raise
ValueError
(
"You must set the scope list when you have Multiple programs"
"You must set the scope list when you have Multiple programs"
)
)
scopes
=
[
fluid
.
global_scope
()]
scopes
=
[
paddle
.
static
.
global_scope
()]
if
len
(
self
.
origin_main_programs
)
!=
len
(
scopes
):
if
len
(
self
.
origin_main_programs
)
!=
len
(
scopes
):
raise
VauleError
(
"len(programs) != len(scopes)"
)
raise
VauleError
(
"len(programs) != len(scopes)"
)
...
@@ -1350,7 +1347,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1350,7 +1347,7 @@ class TheOnePSRuntime(RuntimeBase):
if
self
.
debug
:
if
self
.
debug
:
print
(
"server_desc:
\n
{}"
.
format
(
server_desc
))
print
(
"server_desc:
\n
{}"
.
format
(
server_desc
))
self
.
_server
=
fluid
.
core
.
DistFleetWrapper
()
self
.
_server
=
core
.
DistFleetWrapper
()
self
.
_server
.
init_server
(
self
.
_server
.
init_server
(
server_desc
,
server_desc
,
self
.
string_hosts
,
self
.
string_hosts
,
...
...
python/paddle/distributed/ps/utils/collective_transpiler.py
0 → 100644
浏览文件 @
ab18644c
此差异已折叠。
点击以展开。
python/paddle/distributed/ps/utils/ps_program_builder.py
浏览文件 @
ab18644c
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
paddle
import
paddle.fluid
as
fluid
from
.public
import
*
# noqa: F403
from
.public
import
*
# noqa: F403
from
paddle.distributed.fleet.base.private_helper_function
import
(
from
paddle.distributed.fleet.base.private_helper_function
import
(
wait_server_ready
,
wait_server_ready
,
...
@@ -77,8 +79,8 @@ class PsProgramBuilder:
...
@@ -77,8 +79,8 @@ class PsProgramBuilder:
self
.
_build_trainer_programs
()
self
.
_build_trainer_programs
()
fluid
.
framework
.
switch_startup_program
(
self
.
cloned_startup
)
fluid
.
framework
.
switch_startup_program
(
self
.
cloned_startup
)
print
(
print
(
"
fluid
.default_startup_program: {}"
.
format
(
"
paddle.static
.default_startup_program: {}"
.
format
(
fluid
.
default_startup_program
paddle
.
static
.
default_startup_program
)
)
)
)
# print("ps_program_build before =", id(self.loss.block.program))
# print("ps_program_build before =", id(self.loss.block.program))
...
@@ -471,8 +473,8 @@ class FlPsProgramBuilder(HeterAsyncPsProgramBuilder):
...
@@ -471,8 +473,8 @@ class FlPsProgramBuilder(HeterAsyncPsProgramBuilder):
fluid
.
framework
.
switch_startup_program
(
self
.
cloned_startup
)
fluid
.
framework
.
switch_startup_program
(
self
.
cloned_startup
)
fluid
.
framework
.
switch_main_program
(
self
.
cloned_main
)
fluid
.
framework
.
switch_main_program
(
self
.
cloned_main
)
print
(
print
(
"
fluid
.default_startup_program: {}"
.
format
(
"
paddle.static
.default_startup_program: {}"
.
format
(
fluid
.
default_startup_program
().
_heter_pipeline_opt
paddle
.
static
.
default_startup_program
().
_heter_pipeline_opt
)
)
)
)
else
:
else
:
...
...
python/paddle/distributed/ps/utils/public.py
浏览文件 @
ab18644c
...
@@ -19,7 +19,7 @@ import os
...
@@ -19,7 +19,7 @@ import os
import
warnings
import
warnings
import
logging
import
logging
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.f
luid
import
core
from
paddle.f
ramework
import
core
import
paddle.fluid.framework
as
framework
import
paddle.fluid.framework
as
framework
# logging.basicConfig(
# logging.basicConfig(
...
@@ -896,7 +896,7 @@ def find_heter_ops(program, default_device="cpu"):
...
@@ -896,7 +896,7 @@ def find_heter_ops(program, default_device="cpu"):
if
len
(
heter_ops
)
==
0
:
if
len
(
heter_ops
)
==
0
:
warnings
.
warn
(
warnings
.
warn
(
"No heterogeneous OP was found in your program , "
"No heterogeneous OP was found in your program , "
" please using
fluid
.device_guard() to run OPs on different device."
" please using
static
.device_guard() to run OPs on different device."
)
)
total_heter_ops
=
0
total_heter_ops
=
0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录