Unverified commit 66a31501, authored by tangwei12, committed by GitHub

SYNC with communicator (#22344)

* add sync communicator and implementation
Parent 22bbd547
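End to end, the new SYNC mode is driven from Python through fleet with a sync strategy. Below is a minimal sketch mirroring the test_communicator_sync.py added at the bottom of this diff; the role setup and endpoints are placeholder values, not part of the change itself.

import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

# Placeholder role: worker 0 of 2, two local pserver endpoints.
role = role_maker.UserDefinedRoleMaker(
    current_id=0,
    role=role_maker.Role.WORKER,
    worker_num=2,
    server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)

# A tiny regression net, as in the new test.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=y_predict, label=y))

strategy = StrategyFactory.create_sync_strategy()
optimizer = fleet.distributed_optimizer(fluid.optimizer.SGD(0.01), strategy)
optimizer.minimize(avg_cost)

fleet.init_worker()  # constructs and starts the SyncCommunicator for this trainer
# ... run training steps ...
fleet.stop_worker()  # stops the communicator (no longer skipped for SyncStrategy)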
@@ -180,8 +180,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
if (places_.size() == 1) {
exception_holder_.Clear();
} else {
HandleException();
}
FeedFetchList fetch_data;
......
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h" #include "paddle/fluid/string/split.h"
namespace paddle { namespace paddle {
...@@ -64,7 +65,6 @@ std::shared_ptr<Communicator> Communicator::communicator_(nullptr); ...@@ -64,7 +65,6 @@ std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) { Scope *recv_scope) {
VLOG(0) << "AsyncCommunicator Initializing";
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
@@ -90,7 +90,6 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
Scope *param_scope) {
VLOG(0) << "AsyncCommunicator Initializing";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto *op : program.Block(0).AllOps()) {
@@ -332,8 +331,6 @@ GeoSgdCommunicator::~GeoSgdCommunicator() {
void GeoSgdCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
Scope *recv_scope) {
VLOG(0) << "GeoCommunicator Initializing";
training_scope_ = std::move(recv_scope);
auto geo_send_varnames = envs["geo_send_varnames"];
@@ -954,7 +951,6 @@ void GeoSgdCommunicator::Recv() {}
void HalfAsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {
VLOG(0) << "HalfAsyncCommunicator Initializing";
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
@@ -1011,6 +1007,8 @@ void HalfAsyncCommunicator::InitImpl(
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
VLOG(3) << "find and init an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
@@ -1032,7 +1030,8 @@ void HalfAsyncCommunicator::ConsumeThread() {
VLOG(3) << "ConsumeThread start!";
while (running_) {
while (running_) {
if (barrier_counter_.load() >= barrier_trigger_.load()) {
if (barrier_counter_.load() >= barrier_trigger_.load() &&
    barrier_trigger_.load() != 0) {
break;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
@@ -1096,8 +1095,10 @@ void HalfAsyncCommunicator::ConsumeThread() {
VLOG(3) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
Recv();
BarrierSend();
Recv();
BarrierRecv();
BarrierWeakUp();
}
VLOG(0) << "communicator stopped, send thread exit";
@@ -1200,6 +1201,49 @@ void HalfAsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop done";
}
void SyncCommunicator::BarrierSend() {
if (!running_) return;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);
std::vector<distributed::VarHandlePtr> rets;
for (auto &ep : pserver_endpoints_) {
rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
"internal error in RPCClient"));
}
VLOG(4) << "BarrierSend with SyncCommunicator";
}
void SyncCommunicator::BarrierRecv() {
if (!running_) return;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);
std::vector<distributed::VarHandlePtr> rets;
for (auto &ep : pserver_endpoints_) {
rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
"internal error in RPCClient"));
}
VLOG(4) << "BarrierRecv with SyncCommunicator";
}
SyncCommunicator::~SyncCommunicator() {
running_ = false;
if (consume_thread_) consume_thread_->join();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
DECLARE_bool(communicator_is_sgd_optimizer);
@@ -246,6 +247,7 @@ class AsyncCommunicator : public Communicator {
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
is_sgd_optimizer_ =
static_cast<bool>(std::stoi(envs.at("communicator_is_sgd_optimizer")));
VLOG(0) << "AsyncCommunicator Initialized";
}
~AsyncCommunicator();
void Start() override;
@@ -301,6 +303,7 @@ class HalfAsyncCommunicator : public Communicator {
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
VLOG(0) << "HalfAsyncCommunicator Initialized";
}
~HalfAsyncCommunicator();
void Start() override;
@@ -326,14 +329,17 @@ class HalfAsyncCommunicator : public Communicator {
Scope* recv_scope) override;
void ConsumeThread();
virtual void BarrierSend() {}
virtual void BarrierRecv() {}
private:
protected:
int max_merge_var_num_;
int send_wait_times_;
int thread_pool_size_;
int send_queue_size_;
int trainer_id_ = 0;
private:
protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
@@ -352,6 +358,24 @@ class HalfAsyncCommunicator : public Communicator {
std::atomic<int64_t> barrier_counter_{0};
};
class SyncCommunicator : public HalfAsyncCommunicator {
public:
SyncCommunicator() : HalfAsyncCommunicator() {}
explicit SyncCommunicator(const std::map<std::string, std::string>& envs)
: HalfAsyncCommunicator(envs) {
trainer_id_ = std::stoi(envs.at("trainer_id"));
auto pserver_strings = envs.at("pserver_endpoints");
pserver_endpoints_ = paddle::string::Split(pserver_strings, ',');
VLOG(0) << "SyncCommunicator Initialized";
}
~SyncCommunicator();
void BarrierSend();
void BarrierRecv();
private:
std::vector<std::string> pserver_endpoints_{};
};
class GeoSgdCommunicator : public Communicator {
public:
GeoSgdCommunicator() : Communicator() {}
@@ -361,6 +385,7 @@ class GeoSgdCommunicator : public Communicator {
trainer_nums_ = std::stoi(envs.at("geo_trainer_nums"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
VLOG(0) << "GeoSgdCommunicator Initialized";
}
~GeoSgdCommunicator();
......
@@ -115,6 +115,11 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
*out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
row_offset += outs_dims[i][0];
}
} else {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[0])
->GetMutable<framework::LoDTensor>();
out->ShareDataWith(send_tensor);
}
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
......
@@ -36,6 +36,7 @@ class FetchBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
......
@@ -32,6 +32,7 @@ using paddle::operators::distributed::AsyncCommunicator;
using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::GeoSgdCommunicator;
using paddle::operators::distributed::HalfAsyncCommunicator;
using paddle::operators::distributed::SyncCommunicator;
namespace paddle {
namespace pybind {
@@ -52,6 +53,9 @@ void BindCommunicator(py::module* m) {
} else if (mode == "GEO") {
Communicator::InitInstance<GeoSgdCommunicator>(program, param_scope,
envs);
} else if (mode == "SYNC") {
Communicator::InitInstance<SyncCommunicator>(program, param_scope,
envs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsuported communicator MODE"));
......
@@ -70,6 +70,10 @@ class Communicator(object):
envs["geo_need_push_nums"] = str(kwargs["push_nums"])
envs["geo_send_varnames"] = '#'.join(push_var_names)
if mode == DistributedMode.SYNC:
envs["pserver_endpoints"] = ','.join(kwargs["pserver_endpoints"])
envs["trainer_id"] = str(kwargs["trainer_id"])
mode_str = None
if mode == DistributedMode.SYNC:
......
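For reference, a small sketch of the string handoff the SYNC branch above performs: the Python kwargs become env strings that the C++ SyncCommunicator constructor reads back (the endpoint list and trainer_id below are placeholder values, not part of the change).

# Hypothetical values, for illustration of the SYNC branch above.
kwargs = {
    "pserver_endpoints": ["127.0.0.1:6001", "127.0.0.1:6002"],
    "trainer_id": 0,
}
envs = {}
envs["pserver_endpoints"] = ','.join(kwargs["pserver_endpoints"])
envs["trainer_id"] = str(kwargs["trainer_id"])
# -> {'pserver_endpoints': '127.0.0.1:6001,127.0.0.1:6002', 'trainer_id': '0'}
# The C++ side recovers them via std::stoi(envs.at("trainer_id")) and
# paddle::string::Split(envs.at("pserver_endpoints"), ',').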
@@ -73,9 +73,6 @@ class DistributedTranspiler(Fleet):
trainer_communicator_config = self._transpile_config.get_trainer_runtime_config(
)
if isinstance(self._transpile_config, SyncStrategy):
return
print(trainer_communicator_config)
if isinstance(self._transpile_config, GeoStrategy):
@@ -98,6 +95,17 @@ class DistributedTranspiler(Fleet):
self._communicator = Communicator(
self.main_program, DistributedMode.HALF_ASYNC, None,
trainer_communicator_config.get_communicator_flags())
elif isinstance(self._transpile_config, SyncStrategy):
kwargs = {}
kwargs[
"pserver_endpoints"] = self._role_maker.get_pserver_endpoints()
kwargs["trainer_id"] = self._role_maker.worker_index()
self._communicator = Communicator(
self.main_program, DistributedMode.SYNC, kwargs,
trainer_communicator_config.get_communicator_flags())
else:
raise TypeError("Training MODE do not supported")
@@ -156,8 +164,7 @@ class DistributedTranspiler(Fleet):
None
"""
if not isinstance(self._transpile_config, SyncStrategy):
    self._communicator.stop()
self._communicator.stop()
if isinstance(self._role_maker, MPISymetricRoleMaker):
self._role_maker._finalize()
self._executor.close()
......
@@ -181,9 +181,12 @@ class DistributedStrategy(object):
class SyncStrategy(DistributedStrategy):
def __init__(self):
super(SyncStrategy, self).__init__()
self._program_config.sync_mode = True
self._program_config.runtime_split_send_recv = False
self._build_strategy.async_mode = False
self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
self._program_config.half_async = True
self._program_config.completely_not_async = True
self._execute_strategy.use_thread_barrier = True
num_threads = os.getenv("CPU_NUM", "1")
......
@@ -26,6 +26,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_launch_ps)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_async)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_geo)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_half_async)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_sync)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
@@ -284,6 +285,8 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_communicator_async MODULES test_communicator_async ENVS ${dist_ENVS})
py_test_modules(test_communicator_geo MODULES test_communicator_geo ENVS ${dist_ENVS})
py_test_modules(test_communicator_half_async MODULES test_communicator_half_async ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1)
py_test_modules(test_communicator_sync MODULES test_communicator_sync ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1)
if(WITH_DGC)
# if with dgc, test all dgc tests.
# NOTE. dist dgc tests is already in DIST_TEST_OPS
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import time
import threading
import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
class TestCommunicator(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def test_communicator_sync(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(0.01)
strategy = StrategyFactory.create_sync_strategy()
strategy._program_config.wait_port = False
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
fleet.init_worker()
time.sleep(10)
fleet.stop_worker()
if __name__ == '__main__':
unittest.main()
@@ -25,10 +25,9 @@ class TestStrategyFactor(unittest.TestCase):
def test_sync_strategy(self):
os.environ['CPU_NUM'] = "2"
strategy = StrategyFactory.create_sync_strategy()
self.assertEqual(strategy._program_config.sync_mode, True)
self.assertEqual(strategy._program_config.runtime_split_send_recv,
                 False)
self.assertEqual(strategy._build_strategy.async_mode, False)
self.assertEqual(strategy._program_config.sync_mode, False)
self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
self.assertEqual(strategy._build_strategy.async_mode, True)
self.assertEqual(strategy._execute_strategy.num_threads, 2)
# test set_program_config using DistributeTranspilerConfig()
......
@@ -192,6 +192,7 @@ class DistributeTranspilerConfig(object):
# half_async
half_async = False
completely_not_async = False
# Geo-sgd algorithm
geo_sgd_mode = False
@@ -323,7 +324,7 @@ class DistributeTranspiler(object):
if self.config.split_method is None:
self.config.split_method = RoundRobin
if self.config.sync_mode:
if self.config.sync_mode or self.config.completely_not_async:
self.distributed_mode = DistributedMode.SYNC
elif self.config.runtime_split_send_recv:
self.distributed_mode = DistributedMode.ASYNC
@@ -728,7 +729,14 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
program.global_block().vars[splited_grad_varname]
]
sections = self._get_splited_var_sections(splited_vars)
send_varnames = [var.name for var in splited_vars]
if self.config.completely_not_async:
send_varnames = [
"{}.trainer_{}".format(var.name, self.trainer_id)
for var in splited_vars
]
else:
send_varnames = [var.name for var in splited_vars]
else:
send_input_vars = splited_vars
sections = []
@@ -1199,7 +1207,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
type=v.type,
dtype=v.dtype,
shape=v.shape)
if self.sync_mode and self.trainer_num > 1:
if self.sync_mode or self.config.completely_not_async and self.trainer_num > 1:
for trainer_id in range(self.trainer_num):
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -2204,7 +2212,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
merged_var = pserver_block.vars[merged_var_name]
grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
if self.sync_mode or self.config.completely_not_async and self.trainer_num > 1:
vars2merge = []
for i in range(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
......
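The completely_not_async branch above gives each split gradient a per-trainer suffix, matching the "%s.trainer_%d" variables the pserver side creates and merges in the hunks just shown. A small illustration, using made-up variable names:

# Made-up split var names, for illustration only.
splited_var_names = ["w@GRAD.block0", "w@GRAD.block1"]
trainer_id = 0
send_varnames = [
    "{}.trainer_{}".format(name, trainer_id) for name in splited_var_names
]
print(send_varnames)  # ['w@GRAD.block0.trainer_0', 'w@GRAD.block1.trainer_0']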