Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
66a31501
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
66a31501
编写于
2月 22, 2020
作者:
T
tangwei12
提交者:
GitHub
2月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
SYNC with communicaotor (#22344)
* add sync communicator and implement
上级
22bbd547
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
192 addition
and
27 deletion
+192
-27
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+0
-2
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+51
-7
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+27
-2
paddle/fluid/operators/distributed/parameter_send.cc
paddle/fluid/operators/distributed/parameter_send.cc
+5
-0
paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc
paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc
+1
-0
paddle/fluid/pybind/communicator_py.cc
paddle/fluid/pybind/communicator_py.cc
+4
-0
python/paddle/fluid/communicator.py
python/paddle/fluid/communicator.py
+4
-0
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
.../fleet/parameter_server/distribute_transpiler/__init__.py
+12
-5
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
...eter_server/distribute_transpiler/distributed_strategy.py
+6
-3
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/test_communicator_sync.py
...on/paddle/fluid/tests/unittests/test_communicator_sync.py
+64
-0
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
...paddle/fluid/tests/unittests/test_distributed_strategy.py
+3
-4
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+12
-4
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
66a31501
...
...
@@ -180,8 +180,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
if
(
places_
.
size
()
==
1
)
{
exception_holder_
.
Clear
();
}
else
{
HandleException
();
}
FeedFetchList
fetch_data
;
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
66a31501
...
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
namespace
paddle
{
...
...
@@ -64,7 +65,6 @@ std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
void
AsyncCommunicator
::
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
VLOG
(
0
)
<<
"AsyncCommunicator Initializing"
;
send_varname_to_ctx_
=
std
::
move
(
send_varname_to_ctx
);
recv_varname_to_ctx_
=
std
::
move
(
recv_varname_to_ctx
);
recv_scope_
=
std
::
move
(
recv_scope
);
...
...
@@ -90,7 +90,6 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
void
AsyncCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
)
{
VLOG
(
0
)
<<
"AsyncCommunicator Initializing"
;
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
...
...
@@ -332,8 +331,6 @@ GeoSgdCommunicator::~GeoSgdCommunicator() {
void
GeoSgdCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
{
VLOG
(
0
)
<<
"GeoCommunicator Initializing"
;
training_scope_
=
std
::
move
(
recv_scope
);
auto
geo_send_varnames
=
envs
[
"geo_send_varnames"
];
...
...
@@ -954,7 +951,6 @@ void GeoSgdCommunicator::Recv() {}
void
HalfAsyncCommunicator
::
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
VLOG
(
0
)
<<
"HalfAsyncCommunicator Initializing"
;
send_varname_to_ctx_
=
std
::
move
(
send_varname_to_ctx
);
recv_varname_to_ctx_
=
std
::
move
(
recv_varname_to_ctx
);
recv_scope_
=
std
::
move
(
recv_scope
);
...
...
@@ -1011,6 +1007,8 @@ void HalfAsyncCommunicator::InitImpl(
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{},
trainer_id
);
VLOG
(
3
)
<<
"find and init an recv op: "
<<
recv_varname_to_ctx
[
recv_var_name
];
}
}
...
...
@@ -1032,7 +1030,8 @@ void HalfAsyncCommunicator::ConsumeThread() {
VLOG
(
3
)
<<
"ConsumeThread start!"
;
while
(
running_
)
{
while
(
running_
)
{
if
(
barrier_counter_
.
load
()
>=
barrier_trigger_
.
load
())
{
if
(
barrier_counter_
.
load
()
>=
barrier_trigger_
.
load
()
&&
barrier_trigger_
.
load
()
!=
0
)
{
break
;
}
else
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
...
...
@@ -1096,8 +1095,10 @@ void HalfAsyncCommunicator::ConsumeThread() {
VLOG
(
3
)
<<
"run send graph use time "
<<
after_run_send_graph
-
before_run_send_graph
;
Recv
();
BarrierSend
();
Recv
();
BarrierRecv
();
BarrierWeakUp
();
}
VLOG
(
0
)
<<
"communicator stopped, send thread exit"
;
...
...
@@ -1200,6 +1201,49 @@ void HalfAsyncCommunicator::Stop() {
VLOG
(
0
)
<<
"Communicator stop done"
;
}
void
SyncCommunicator
::
BarrierSend
()
{
if
(
!
running_
)
return
;
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id_
);
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
auto
&
ep
:
pserver_endpoints_
)
{
rets
.
push_back
(
rpc_client
->
AsyncSendBatchBarrier
(
ep
));
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE_NE
(
rets
[
i
]
->
Wait
(),
0U
,
platform
::
errors
::
External
(
"internal error in RPCClient"
));
}
VLOG
(
4
)
<<
"BarrierSend with SyncCommunicator"
;
}
void
SyncCommunicator
::
BarrierRecv
()
{
if
(
!
running_
)
return
;
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id_
);
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
auto
&
ep
:
pserver_endpoints_
)
{
rets
.
push_back
(
rpc_client
->
AsyncSendFetchBarrier
(
ep
));
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE_NE
(
rets
[
i
]
->
Wait
(),
0U
,
platform
::
errors
::
External
(
"internal error in RPCClient"
));
}
VLOG
(
4
)
<<
"BarrierRecv with SyncCommunicator"
;
}
SyncCommunicator
::~
SyncCommunicator
()
{
running_
=
false
;
if
(
consume_thread_
)
consume_thread_
->
join
();
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
66a31501
...
...
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
DECLARE_bool
(
communicator_is_sgd_optimizer
);
...
...
@@ -246,6 +247,7 @@ class AsyncCommunicator : public Communicator {
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
is_sgd_optimizer_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"communicator_is_sgd_optimizer"
)));
VLOG
(
0
)
<<
"AsyncCommunicator Initialized"
;
}
~
AsyncCommunicator
();
void
Start
()
override
;
...
...
@@ -301,6 +303,7 @@ class HalfAsyncCommunicator : public Communicator {
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
VLOG
(
0
)
<<
"HalfAsyncCommunicator Initialized"
;
}
~
HalfAsyncCommunicator
();
void
Start
()
override
;
...
...
@@ -326,14 +329,17 @@ class HalfAsyncCommunicator : public Communicator {
Scope
*
recv_scope
)
override
;
void
ConsumeThread
();
virtual
void
BarrierSend
()
{}
virtual
void
BarrierRecv
()
{}
pr
ivate
:
pr
otected
:
int
max_merge_var_num_
;
int
send_wait_times_
;
int
thread_pool_size_
;
int
send_queue_size_
;
int
trainer_id_
=
0
;
pr
ivate
:
pr
otected
:
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>>
send_varname_to_queue_
;
...
...
@@ -352,6 +358,24 @@ class HalfAsyncCommunicator : public Communicator {
std
::
atomic
<
int64_t
>
barrier_counter_
{
0
};
};
class
SyncCommunicator
:
public
HalfAsyncCommunicator
{
public:
SyncCommunicator
()
:
HalfAsyncCommunicator
()
{}
explicit
SyncCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
envs
)
:
HalfAsyncCommunicator
(
envs
)
{
trainer_id_
=
std
::
stoi
(
envs
.
at
(
"trainer_id"
));
auto
pserver_strings
=
envs
.
at
(
"pserver_endpoints"
);
pserver_endpoints_
=
paddle
::
string
::
Split
(
pserver_strings
,
','
);
VLOG
(
0
)
<<
"SyncCommunicator Initialized"
;
}
~
SyncCommunicator
();
void
BarrierSend
();
void
BarrierRecv
();
private:
std
::
vector
<
std
::
string
>
pserver_endpoints_
{};
};
class
GeoSgdCommunicator
:
public
Communicator
{
public:
GeoSgdCommunicator
()
:
Communicator
()
{}
...
...
@@ -361,6 +385,7 @@ class GeoSgdCommunicator : public Communicator {
trainer_nums_
=
std
::
stoi
(
envs
.
at
(
"geo_trainer_nums"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
VLOG
(
0
)
<<
"GeoSgdCommunicator Initialized"
;
}
~
GeoSgdCommunicator
();
...
...
paddle/fluid/operators/distributed/parameter_send.cc
浏览文件 @
66a31501
...
...
@@ -115,6 +115,11 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
*
out
=
send_tensor
.
Slice
(
row_offset
,
row_offset
+
outs_dims
[
i
][
0
]);
row_offset
+=
outs_dims
[
i
][
0
];
}
}
else
{
auto
&
send_tensor
=
send_var
->
Get
<
framework
::
LoDTensor
>
();
framework
::
Tensor
*
out
=
local_scope
->
Var
(
rpc_ctx
.
splited_var_names
[
0
])
->
GetMutable
<
framework
::
LoDTensor
>
();
out
->
ShareDataWith
(
send_tensor
);
}
if
(
rpc_ctx
.
use_send_handler
)
{
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
...
...
paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc
浏览文件 @
66a31501
...
...
@@ -36,6 +36,7 @@ class FetchBarrierOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
...
...
paddle/fluid/pybind/communicator_py.cc
浏览文件 @
66a31501
...
...
@@ -32,6 +32,7 @@ using paddle::operators::distributed::AsyncCommunicator;
using
paddle
::
operators
::
distributed
::
Communicator
;
using
paddle
::
operators
::
distributed
::
GeoSgdCommunicator
;
using
paddle
::
operators
::
distributed
::
HalfAsyncCommunicator
;
using
paddle
::
operators
::
distributed
::
SyncCommunicator
;
namespace
paddle
{
namespace
pybind
{
...
...
@@ -52,6 +53,9 @@ void BindCommunicator(py::module* m) {
}
else
if
(
mode
==
"GEO"
)
{
Communicator
::
InitInstance
<
GeoSgdCommunicator
>
(
program
,
param_scope
,
envs
);
}
else
if
(
mode
==
"SYNC"
)
{
Communicator
::
InitInstance
<
SyncCommunicator
>
(
program
,
param_scope
,
envs
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"unsuported communicator MODE"
));
...
...
python/paddle/fluid/communicator.py
浏览文件 @
66a31501
...
...
@@ -70,6 +70,10 @@ class Communicator(object):
envs
[
"geo_need_push_nums"
]
=
str
(
kwargs
[
"push_nums"
])
envs
[
"geo_send_varnames"
]
=
'#'
.
join
(
push_var_names
)
if
mode
==
DistributedMode
.
SYNC
:
envs
[
"pserver_endpoints"
]
=
','
.
join
(
kwargs
[
"pserver_endpoints"
])
envs
[
"trainer_id"
]
=
str
(
kwargs
[
"trainer_id"
])
mode_str
=
None
if
mode
==
DistributedMode
.
SYNC
:
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
浏览文件 @
66a31501
...
...
@@ -73,9 +73,6 @@ class DistributedTranspiler(Fleet):
trainer_communicator_config
=
self
.
_transpile_config
.
get_trainer_runtime_config
(
)
if
isinstance
(
self
.
_transpile_config
,
SyncStrategy
):
return
print
(
trainer_communicator_config
)
if
isinstance
(
self
.
_transpile_config
,
GeoStrategy
):
...
...
@@ -98,6 +95,17 @@ class DistributedTranspiler(Fleet):
self
.
_communicator
=
Communicator
(
self
.
main_program
,
DistributedMode
.
HALF_ASYNC
,
None
,
trainer_communicator_config
.
get_communicator_flags
())
elif
isinstance
(
self
.
_transpile_config
,
SyncStrategy
):
kwargs
=
{}
kwargs
[
"pserver_endpoints"
]
=
self
.
_role_maker
.
get_pserver_endpoints
()
kwargs
[
"trainer_id"
]
=
self
.
_role_maker
.
worker_index
()
self
.
_communicator
=
Communicator
(
self
.
main_program
,
DistributedMode
.
SYNC
,
kwargs
,
trainer_communicator_config
.
get_communicator_flags
())
else
:
raise
TypeError
(
"Training MODE do not supported"
)
...
...
@@ -156,8 +164,7 @@ class DistributedTranspiler(Fleet):
None
"""
if
not
isinstance
(
self
.
_transpile_config
,
SyncStrategy
):
self
.
_communicator
.
stop
()
self
.
_communicator
.
stop
()
if
isinstance
(
self
.
_role_maker
,
MPISymetricRoleMaker
):
self
.
_role_maker
.
_finalize
()
self
.
_executor
.
close
()
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
浏览文件 @
66a31501
...
...
@@ -181,9 +181,12 @@ class DistributedStrategy(object):
class
SyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
SyncStrategy
,
self
).
__init__
()
self
.
_program_config
.
sync_mode
=
True
self
.
_program_config
.
runtime_split_send_recv
=
False
self
.
_build_strategy
.
async_mode
=
False
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_program_config
.
half_async
=
True
self
.
_program_config
.
completely_not_async
=
True
self
.
_execute_strategy
.
use_thread_barrier
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
66a31501
...
...
@@ -26,6 +26,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_launch_ps)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_async
)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_geo
)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_half_async
)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_sync
)
list
(
APPEND MIXED_DIST_TEST_OPS test_fleet_api_input
)
foreach
(
TEST_OP
${
MIXED_DIST_TEST_OPS
}
)
list
(
REMOVE_ITEM TEST_OPS
${
TEST_OP
}
)
...
...
@@ -284,6 +285,8 @@ if(WITH_DISTRIBUTE)
py_test_modules
(
test_communicator_async MODULES test_communicator_async ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_communicator_geo MODULES test_communicator_geo ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_communicator_half_async MODULES test_communicator_half_async ENVS
${
dist_ENVS
}
FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1
)
py_test_modules
(
test_communicator_sync MODULES test_communicator_sync ENVS
${
dist_ENVS
}
FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1
)
if
(
WITH_DGC
)
# if with dgc, test all dgc tests.
# NOTE. dist dgc tests is already in DIST_TEST_OPS
...
...
python/paddle/fluid/tests/unittests/test_communicator_sync.py
0 → 100644
浏览文件 @
66a31501
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
time
import
threading
import
numpy
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.communicator
import
Communicator
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler
import
fleet
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy
import
StrategyFactory
class
TestCommunicator
(
unittest
.
TestCase
):
def
net
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
return
avg_cost
def
test_communicator_sync
(
self
):
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
0
,
role
=
role_maker
.
Role
.
WORKER
,
worker_num
=
2
,
server_endpoints
=
[
"127.0.0.1:6001"
,
"127.0.0.1:6002"
])
fleet
.
init
(
role
)
avg_cost
=
self
.
net
()
optimizer
=
fluid
.
optimizer
.
SGD
(
0.01
)
strategy
=
StrategyFactory
.
create_sync_strategy
()
strategy
.
_program_config
.
wait_port
=
False
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
fleet
.
init_worker
()
time
.
sleep
(
10
)
fleet
.
stop_worker
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
浏览文件 @
66a31501
...
...
@@ -25,10 +25,9 @@ class TestStrategyFactor(unittest.TestCase):
def
test_sync_strategy
(
self
):
os
.
environ
[
'CPU_NUM'
]
=
"2"
strategy
=
StrategyFactory
.
create_sync_strategy
()
self
.
assertEqual
(
strategy
.
_program_config
.
sync_mode
,
True
)
self
.
assertEqual
(
strategy
.
_program_config
.
runtime_split_send_recv
,
False
)
self
.
assertEqual
(
strategy
.
_build_strategy
.
async_mode
,
False
)
self
.
assertEqual
(
strategy
.
_program_config
.
sync_mode
,
False
)
self
.
assertEqual
(
strategy
.
_program_config
.
runtime_split_send_recv
,
True
)
self
.
assertEqual
(
strategy
.
_build_strategy
.
async_mode
,
True
)
self
.
assertEqual
(
strategy
.
_execute_strategy
.
num_threads
,
2
)
# test set_program_config using DistributeTranspilerConfig()
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
66a31501
...
...
@@ -192,6 +192,7 @@ class DistributeTranspilerConfig(object):
# half_async
half_async
=
False
completely_not_async
=
False
# Geo-sgd algorithm
geo_sgd_mode
=
False
...
...
@@ -323,7 +324,7 @@ class DistributeTranspiler(object):
if
self
.
config
.
split_method
is
None
:
self
.
config
.
split_method
=
RoundRobin
if
self
.
config
.
sync_mode
:
if
self
.
config
.
sync_mode
or
self
.
config
.
completely_not_async
:
self
.
distributed_mode
=
DistributedMode
.
SYNC
elif
self
.
config
.
runtime_split_send_recv
:
self
.
distributed_mode
=
DistributedMode
.
ASYNC
...
...
@@ -728,7 +729,14 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
program
.
global_block
().
vars
[
splited_grad_varname
]
]
sections
=
self
.
_get_splited_var_sections
(
splited_vars
)
send_varnames
=
[
var
.
name
for
var
in
splited_vars
]
if
self
.
config
.
completely_not_async
:
send_varnames
=
[
"{}.trainer_{}"
.
format
(
var
.
name
,
self
.
trainer_id
)
for
var
in
splited_vars
]
else
:
send_varnames
=
[
var
.
name
for
var
in
splited_vars
]
else
:
send_input_vars
=
splited_vars
sections
=
[]
...
...
@@ -1199,7 +1207,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
type
=
v
.
type
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
if
self
.
sync_mode
or
self
.
config
.
completely_not_async
and
self
.
trainer_num
>
1
:
for
trainer_id
in
range
(
self
.
trainer_num
):
var
=
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
orig_var_name
,
trainer_id
),
...
...
@@ -2204,7 +2212,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
merged_var
=
pserver_block
.
vars
[
merged_var_name
]
grad_to_block_id
.
append
(
merged_var
.
name
+
":"
+
str
(
optimize_block
.
idx
))
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
if
self
.
sync_mode
or
self
.
config
.
completely_not_async
and
self
.
trainer_num
>
1
:
vars2merge
=
[]
for
i
in
range
(
self
.
trainer_num
):
per_trainer_name
=
"%s.trainer_%d"
%
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录