Unverified commit 985bceac authored by 123malin, committed by GitHub

Bug fix for sparse recorder (#21969)

* test=develop, bug fix for sparse recorder
Parent 7e2af4c9
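For context: the diff below replaces the boolean sync_mode flag that the RPC request handlers used to receive with an integer distributed_mode taken from a new DistributedMode enum (kSync, kAsync, kHalfAsync, kGeo), and makes the sparse-update recorder geo-mode-only. A minimal standalone C++ sketch of the new pattern (simplified names, not the actual Paddle classes):

#include <iostream>

// Mirrors the enum this commit adds to request_handler.h.
enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };

// Simplified stand-in for RequestHandler: it now stores an int mode
// instead of a bool sync_mode.
class RequestHandlerSketch {
 public:
  explicit RequestHandlerSketch(int distributed_mode)
      : distributed_mode_(distributed_mode) {}
  int distributed_mode() const { return distributed_mode_; }

 private:
  const int distributed_mode_;
};

int main() {
  // Call sites that used to pass true/false now pass an enum value ...
  RequestHandlerSketch send_h(DistributedMode::kSync);
  // ... and branches that tested `if (sync_mode_)` compare against the enum.
  if (send_h.distributed_mode() == DistributedMode::kSync) {
    std::cout << "running in sync mode" << std::endl;
  }
  return 0;
}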
@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a weird crash.
distributed::RequestSendHandler rpc_h(true);
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService {
distributed::BRPCVariableResponse resp(request_send_h_->scope(),
request_send_h_->dev_ctx(),
!request_send_h_->sync_mode());
request_send_h_->distributed_mode());
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase {
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(),
!request_handler->sync_mode()));
request_.reset(new GRPCVariableResponse(
request_handler->scope(), request_handler->dev_ctx(),
request_handler->distributed_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase {
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(),
!request_handler->sync_mode()));
request_.reset(new GRPCVariableResponse(
request_handler->scope(), request_handler->dev_ctx(),
request_handler->distributed_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };
class RPCServer;
class VarHandle {
@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler {
public:
explicit RequestHandler(bool sync_mode)
: sync_mode_(sync_mode),
explicit RequestHandler(int distributed_mode)
: distributed_mode_(distributed_mode),
dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
@@ -198,7 +200,7 @@ class RequestHandler {
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
bool sync_mode() { return sync_mode_; }
int distributed_mode() { return distributed_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ProgramDesc* program() { return program_; }
@@ -225,7 +227,7 @@ class RequestHandler {
const std::string& table_name = "") = 0;
protected:
const bool sync_mode_;
const int distributed_mode_;
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
rpc_server_->Complete();
} else {
// Async
if (!sync_mode_) {
if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW(
@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope->Rename(varname, run_varname);
}
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
if (sync_mode_) {
if (distributed_mode_ == DistributedMode::kSync) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
VLOG(1) << "Table name empty? " << table_name.empty();
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
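The hunks above contain the core of the fix: AsyncSparseParamUpdateRecorder is now consulted only when the handler runs in geo mode, matching the fact that the recorder singleton is only initialized for geo mode (see the listen_and_serv_op hunk further down). A minimal sketch of the guard, assuming a simplified stand-in recorder class (not the actual Paddle implementation):

#include <iostream>
#include <memory>
#include <string>

enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };

// Simplified stand-in for AsyncSparseParamUpdateRecorder: a singleton that
// only exists after Init() has been called, which the server does for geo
// mode only.
class RecorderSketch {
 public:
  static void Init() { instance_.reset(new RecorderSketch()); }
  static RecorderSketch* GetInstance() { return instance_.get(); }
  bool HasGrad(const std::string& name) const { return !name.empty(); }

 private:
  static std::unique_ptr<RecorderSketch> instance_;
};

std::unique_ptr<RecorderSketch> RecorderSketch::instance_;

void HandleSendSketch(int distributed_mode, const std::string& run_varname) {
  // The mode check comes first, so non-geo modes never touch the singleton,
  // which is the pattern the hunk above introduces.
  if (distributed_mode == DistributedMode::kGeo &&
      RecorderSketch::GetInstance()->HasGrad(run_varname)) {
    std::cout << "record sparse update for " << run_varname << std::endl;
  }
}

int main() {
  HandleSendSketch(DistributedMode::kAsync, "emb@GRAD");  // guard short-circuits
  RecorderSketch::Init();                                 // done only for geo mode
  HandleSendSketch(DistributedMode::kGeo, "emb@GRAD");    // recorder is consulted
  return 0;
}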
@@ -38,8 +38,8 @@ namespace distributed {
class RequestSendHandler final : public RequestHandler {
public:
explicit RequestSendHandler(bool sync_mode, bool enable_dc_asgd = false)
: RequestHandler(sync_mode) {
explicit RequestSendHandler(int distributed_mode, bool enable_dc_asgd = false)
: RequestHandler(distributed_mode) {
enable_dc_asgd_ = enable_dc_asgd;
}
virtual ~RequestSendHandler() {}
@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler {
class RequestGetHandler final : public RequestHandler {
public:
explicit RequestGetHandler(bool sync_mode, bool enable_dc_asgd = false)
: RequestHandler(sync_mode) {
explicit RequestGetHandler(int distributed_mode, bool enable_dc_asgd = false)
: RequestHandler(distributed_mode) {
enable_dc_asgd_ = enable_dc_asgd;
}
virtual ~RequestGetHandler() {}
@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name,
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
explicit RequestPrefetchHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {
explicit RequestCheckpointHandler(int distributed_mode,
int checkpoint_notify_id)
: RequestHandler(distributed_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {}
@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler {
class RequestNotifyHandler final : public RequestHandler {
public:
explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id)
: RequestHandler(sync_mode) {
explicit RequestNotifyHandler(int distributed_mode, int lr_decay_block_id)
: RequestHandler(distributed_mode) {
this->lr_decay_block_id = lr_decay_block_id;
}
virtual ~RequestNotifyHandler() {}
@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) {
TEST(PREFETCH, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
g_req_handler.reset(new distributed::RequestPrefetchHandler(
distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) {
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, false));
new distributed::RequestSendHandler(!sync_mode, false));
request_get_handler_.reset(
new distributed::RequestGetHandler(sync_mode, false));
new distributed::RequestGetHandler(!sync_mode, false));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a weird crash.
distributed::RequestSendHandler rpc_h(true);
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
int distributed_mode = Attr<int>("distributed_mode");
bool dc_sgd = Attr<bool>("dc_asgd");
auto fan_in = Attr<int>("Fanin");
auto pserver_id = Attr<int>("pserver_id");
@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
int lr_decay_block_id = Attr<int>(kLRDecayBlockId);
VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode
<< ", fan_in:" << fan_in << ", end_point:" << endpoint
VLOG(4) << "pserver_id: " << pserver_id
<< ", distributed_mode:" << distributed_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id
<< ", lr_decay_block_id: " << lr_decay_block_id;
@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");
request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, dc_sgd));
new distributed::RequestSendHandler(distributed_mode, dc_sgd));
request_get_handler_.reset(
new distributed::RequestGetHandler(sync_mode, dc_sgd));
new distributed::RequestGetHandler(distributed_mode, dc_sgd));
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode));
new distributed::RequestPrefetchHandler(distributed_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id));
distributed_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id));
request_notify_handler_.reset(new distributed::RequestNotifyHandler(
distributed_mode, lr_decay_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), rpc_send_thread_num);
@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
if (sync_mode) {
if (distributed_mode == distributed::DistributedMode::kSync) {
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id);
} else {
distributed::AsyncSparseParamUpdateRecorder::Init(
fan_in, sparse_grad_name_to_param_name);
if (distributed_mode == distributed::DistributedMode::kGeo) {
distributed::AsyncSparseParamUpdateRecorder::Init(
fan_in, sparse_grad_name_to_param_name);
}
VLOG(2) << "RunAsyncLoop";
auto grad_to_block_id_str =
@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<int>("distributed_mode",
"indicate distriubte training mode, 0 is sync, 1 is "
"fully-async, 2 is half-async, 3 is geo")
.SetDefault(0);
AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
.SetDefault(false);
AddAttr<std::vector<framework::BlockDesc *>>(
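The hunks above switch listen_and_serv from a bool sync_mode attribute to an int distributed_mode attribute and initialize the sparse-update recorder only for geo mode. A rough sketch of that dispatch, using a plain map in place of the op's attribute machinery (illustrative only, not the actual op code):

#include <iostream>
#include <map>
#include <string>

enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };

// Reads the integer attribute and picks the server loop, mirroring the
// dispatch in the hunks above (names simplified).
void RunListenAndServSketch(const std::map<std::string, int>& attrs) {
  const int distributed_mode = attrs.at("distributed_mode");

  if (distributed_mode == DistributedMode::kSync) {
    std::cout << "start RPC server, run sync loop" << std::endl;
  } else {
    if (distributed_mode == DistributedMode::kGeo) {
      // Only geo mode initializes the sparse-parameter update recorder.
      std::cout << "init AsyncSparseParamUpdateRecorder" << std::endl;
    }
    std::cout << "run async loop" << std::endl;
  }
}

int main() {
  RunListenAndServSketch({{"distributed_mode", DistributedMode::kGeo}});
  RunListenAndServSketch({{"distributed_mode", DistributedMode::kHalfAsync}});
  return 0;
}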
@@ -32,9 +32,11 @@ USE_OP(sum);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
namespace d = paddle::operators::distributed;
// global for simplicity.
std::unique_ptr<f::OperatorBase> listen_and_serv_op;
// global for simplicity.
std::unique_ptr<f::OperatorBase>
listen_and_serv_op;
int selected_port;
void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs.insert({"optimize_blocks", optimize_blocks});
attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true});
attrs.insert({"distributed_mode", d::DistributedMode::kSync});
VLOG(4) << "before init op";
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
@@ -72,7 +72,8 @@ void StartServer() {
}
TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
default_startup_program, program_guard, Program, Variable
from ..layer_helper import LayerHelper
from ..unique_name import generate as unique_name
from ..transpiler.distribute_transpiler import DistributedMode
import logging
__all__ = [
@@ -240,7 +241,8 @@ class ListenAndServ(object):
'optimize_blocks': [
current_block
], # did not support multiple optimize blocks in layers
'sync_mode': True, # did not support async now in layers
'distributed_mode':
DistributedMode.SYNC, # did not support async now in layers
'grad_to_block_id': [""]
})
@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase):
self.origin_prog = main.clone()
return main
def get_trainer(self, config=None):
def get_trainer(self, config=None, sync_mode=True):
src = fluid.default_startup_program().clone()
t = self._transpiler_instance(config)
t = self._transpiler_instance(config, sync_mode=True)
trainer_main = t.get_trainer_program(wait_port=False)
trainer_startup = fluid.default_startup_program()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import gc
gc.set_debug(gc.DEBUG_COLLECTABLE)
class TranspilerTest(unittest.TestCase):
def setUp(self):
self.trainer_id = 0
self.trainers = 2
self.pservers = 2
# NOTE: we do not actually bind this port
self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"
self.pserver1_ep = "127.0.0.1:6174"
self.pserver2_ep = "127.0.0.1:6175"
self.sync_mode = True
self.transpiler = None
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost)
def get_main_program(self):
main = fluid.Program()
main.random_seed = 1
with fluid.program_guard(main):
self.net_conf()
self.origin_prog = main.clone()
return main
def get_trainer(self, config=None, sync_mode=True):
src = fluid.default_startup_program().clone()
t = self._transpiler_instance(config, sync_mode=True)
trainer_main = t.get_trainer_program(wait_port=False)
trainer_startup = fluid.default_startup_program()
assert (src.num_blocks == 1)
assert (trainer_startup.num_blocks == src.num_blocks)
return trainer_main, trainer_startup
def get_pserver(self, ep, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
pserver = t.get_pserver_program(ep)
startup = t.get_startup_program(ep, pserver)
return pserver, startup
def _transpiler_instance(self, config=None, sync_mode=True):
if not self.transpiler:
main = self.get_main_program()
self.transpiler = fluid.DistributeTranspiler(config=config)
self.transpiler.transpile(
self.trainer_id,
program=main,
pservers=self.pserver_eps,
trainers=self.trainers,
sync_mode=sync_mode)
return self.transpiler
def transpiler_test_impl(self):
pass
def test_transpiler(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.transpiler_test_impl()
# NOTE: run gc.collect to eliminate pybind side objects to
# prevent random double-deallocate when inherited in python.
del self.transpiler
del main
del startup
gc.collect()
class TestBasicModelAsync(TranspilerTest):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
config.runtime_split_send_recv = True
pserver, startup = self.get_pserver(self.pserver1_ep, config, False)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config, False)
trainer, _ = self.get_trainer(config, False)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'recv', 'recv'
])
self.assertEqual(len(pserver.blocks), 3)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
self.assertEqual(pserver.blocks[0].ops[0].attr("distributed_mode"), 1)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops], ["sgd"])
class TestBasicModelHalfAsync(TranspilerTest):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
config.runtime_split_send_recv = False
pserver, startup = self.get_pserver(self.pserver1_ep, config, False)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config, False)
trainer, _ = self.get_trainer(config, False)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send',
'recv', 'recv', 'concat'
])
self.assertEqual(len(pserver.blocks), 3)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
self.assertEqual(pserver.blocks[0].ops[0].attr("distributed_mode"), 2)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops], ["sgd"])
class TestBasicModelSync(TranspilerTest):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = True
config.runtime_split_send_recv = False
pserver, startup = self.get_pserver(self.pserver1_ep, config, True)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config, True)
trainer, _ = self.get_trainer(config, True)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send',
'send_barrier', 'recv', 'recv', 'fetch_barrier', 'concat'
])
self.assertEqual(len(pserver.blocks), 3)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
self.assertEqual(pserver.blocks[0].ops[0].attr("distributed_mode"), 0)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops],
["sum", "scale", "sgd"])
if __name__ == "__main__":
unittest.main()
@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag
flag2=test_list_and_serv_run_empty_optimize_block.flag
for i in {1..10}; do
sleep 3s
sleep 6s
if [[ -f "${flag1}" && -f "${flag2}" ]]; then
echo "test_listen_and_serv_op exit"
exit 0
@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = sync_mode
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
config.slice_var_up = False
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
pserver_prog = t.get_pserver_program(ps2)
# pserver2 have no parameter
@@ -25,6 +25,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
from dist_test_utils import *
@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0',
"Fanin": 1,
"sync_mode": True,
"distributed_mode": DistributedMode.SYNC,
"grad_to_block_id": []
})
@@ -26,6 +26,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
from dist_test_utils import *
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
def nce(input, weight, bias, sample_weight, labels, num_classes,
@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0',
"Fanin": 1,
"sync_mode": True,
"distributed_mode": DistributedMode.SYNC,
"grad_to_block_id": []
})
@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler.details import VarStruct, VarsDistributed
from dist_test_utils import *
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
def run_pserver(pserver_id):
@@ -56,7 +57,7 @@ def run_pserver(pserver_id):
"optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0',
"Fanin": 1,
"sync_mode": True,
"distributed_mode": DistributedMode.SYNC,
"grad_to_block_id": []
})
@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
PRINT_LOG = False
class DistributedMode:
SYNC = 0
ASYNC = 1
HALF_ASYNC = 2
GEO = 3
def log(*args):
if PRINT_LOG:
print(args)
@@ -313,6 +320,13 @@ class DistributeTranspiler(object):
if self.config.split_method is None:
self.config.split_method = RoundRobin
if self.config.sync_mode:
self.distributed_mode = DistributedMode.SYNC
elif self.config.runtime_split_send_recv:
self.distributed_mode = DistributedMode.ASYNC
else:
self.distributed_mode = DistributedMode.HALF_ASYNC
global PRINT_LOG
if self.config.print_log:
PRINT_LOG = True
@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object):
"endpoint": endpoint,
"pserver_id": self.pserver_endpoints.index(endpoint),
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"distributed_mode": self.distributed_mode,
"grad_to_block_id": grad_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param,
"lr_decay_block_id": lr_decay_block_id,
@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
from .details import wait_server_ready, VarsDistributed
from .details import delete_ops
from ..distribute_lookup_table import find_distributed_lookup_table
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig, DistributedMode
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
"optimize_blocks": optimize_block,
"endpoint": endpoint,
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"distributed_mode": DistributedMode.GEO,
"grad_to_block_id": param_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param,
"rpc_get_thread_num": self.server_config._rpc_get_thread_num,