未验证 提交 985bceac 编写于 作者: 1 123malin 提交者: GitHub

Bug fix for sparse recorder (#21969)

* test=develop, bug fix for sparse recorder
上级 7e2af4c9
...@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
distributed::RequestSendHandler rpc_h(true); distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service( std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1)); new RPCSERVER_T(endpoint, 1));
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h" #include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
...@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService {
distributed::BRPCVariableResponse resp(request_send_h_->scope(), distributed::BRPCVariableResponse resp(request_send_h_->scope(),
request_send_h_->dev_ctx(), request_send_h_->dev_ctx(),
!request_send_h_->sync_mode()); request_send_h_->distributed_mode());
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0, PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!"); "parse iobuf to tensor error!");
......
...@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase { ...@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(
request_handler->dev_ctx(), request_handler->scope(), request_handler->dev_ctx(),
!request_handler->sync_mode())); request_handler->distributed_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable); int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
...@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase { ...@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(
request_handler->dev_ctx(), request_handler->scope(), request_handler->dev_ctx(),
!request_handler->sync_mode())); request_handler->distributed_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify); int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
......
...@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; ...@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };
class RPCServer; class RPCServer;
class VarHandle { class VarHandle {
...@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr; ...@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler { class RequestHandler {
public: public:
explicit RequestHandler(bool sync_mode) explicit RequestHandler(int distributed_mode)
: sync_mode_(sync_mode), : distributed_mode_(distributed_mode),
dev_ctx_(nullptr), dev_ctx_(nullptr),
executor_(nullptr), executor_(nullptr),
scope_(nullptr), scope_(nullptr),
...@@ -198,7 +200,7 @@ class RequestHandler { ...@@ -198,7 +200,7 @@ class RequestHandler {
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes. // Get attributes.
bool sync_mode() { return sync_mode_; } int distributed_mode() { return distributed_mode_; }
framework::Scope* scope() { return scope_; } framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; } const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ProgramDesc* program() { return program_; } framework::ProgramDesc* program() { return program_; }
...@@ -225,7 +227,7 @@ class RequestHandler { ...@@ -225,7 +227,7 @@ class RequestHandler {
const std::string& table_name = "") = 0; const std::string& table_name = "") = 0;
protected: protected:
const bool sync_mode_; const int distributed_mode_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_; framework::Executor* executor_;
......
...@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
rpc_server_->Complete(); rpc_server_->Complete();
} else { } else {
// Async // Async
if (!sync_mode_) { if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname; VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW( PADDLE_THROW(
...@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope->Rename(varname, run_varname); scope->Rename(varname, run_varname);
} }
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) { if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr = auto& grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>(); scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname, AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
...@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name; << " table_name: " << table_name;
if (sync_mode_) { if (distributed_mode_ == DistributedMode::kSync) {
if (varname == FETCH_BARRIER_MESSAGE) { if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message"; VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet); rpc_server_->IncreaseBatchBarrier(kRequestGet);
...@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
} }
VLOG(1) << "Table name empty? " << table_name.empty(); VLOG(1) << "Table name empty? " << table_name.empty();
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist " VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam( << AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname); varname);
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && }
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) { !table_name.empty()) {
std::vector<int64_t> updated_rows; std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
......
...@@ -38,8 +38,8 @@ namespace distributed { ...@@ -38,8 +38,8 @@ namespace distributed {
class RequestSendHandler final : public RequestHandler { class RequestSendHandler final : public RequestHandler {
public: public:
explicit RequestSendHandler(bool sync_mode, bool enable_dc_asgd = false) explicit RequestSendHandler(int distributed_mode, bool enable_dc_asgd = false)
: RequestHandler(sync_mode) { : RequestHandler(distributed_mode) {
enable_dc_asgd_ = enable_dc_asgd; enable_dc_asgd_ = enable_dc_asgd;
} }
virtual ~RequestSendHandler() {} virtual ~RequestSendHandler() {}
...@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler { ...@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler {
class RequestGetHandler final : public RequestHandler { class RequestGetHandler final : public RequestHandler {
public: public:
explicit RequestGetHandler(bool sync_mode, bool enable_dc_asgd = false) explicit RequestGetHandler(int distributed_mode, bool enable_dc_asgd = false)
: RequestHandler(sync_mode) { : RequestHandler(distributed_mode) {
enable_dc_asgd_ = enable_dc_asgd; enable_dc_asgd_ = enable_dc_asgd;
} }
virtual ~RequestGetHandler() {} virtual ~RequestGetHandler() {}
...@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name, ...@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name,
class RequestPrefetchHandler final : public RequestHandler { class RequestPrefetchHandler final : public RequestHandler {
public: public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestPrefetchHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestPrefetchHandler() {} virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
...@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler { class RequestCheckpointHandler final : public RequestHandler {
public: public:
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id) explicit RequestCheckpointHandler(int distributed_mode,
: RequestHandler(sync_mode) { int checkpoint_notify_id)
: RequestHandler(distributed_mode) {
this->checkpoint_notify_id = checkpoint_notify_id; this->checkpoint_notify_id = checkpoint_notify_id;
} }
virtual ~RequestCheckpointHandler() {} virtual ~RequestCheckpointHandler() {}
...@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler { ...@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler {
class RequestNotifyHandler final : public RequestHandler { class RequestNotifyHandler final : public RequestHandler {
public: public:
explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id) explicit RequestNotifyHandler(int distributed_mode, int lr_decay_block_id)
: RequestHandler(sync_mode) { : RequestHandler(distributed_mode) {
this->lr_decay_block_id = lr_decay_block_id; this->lr_decay_block_id = lr_decay_block_id;
} }
virtual ~RequestNotifyHandler() {} virtual ~RequestNotifyHandler() {}
......
...@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) { ...@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) {
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestPrefetchHandler(true)); g_req_handler.reset(new distributed::RequestPrefetchHandler(
distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client = distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
...@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) { ...@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) {
TEST(COMPLETE, CPU) { TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestSendHandler(true)); g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client = distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
......
...@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset( request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, false)); new distributed::RequestSendHandler(!sync_mode, false));
request_get_handler_.reset( request_get_handler_.reset(
new distributed::RequestGetHandler(sync_mode, false)); new distributed::RequestGetHandler(!sync_mode, false));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), request_send_handler_.get(),
......
...@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
distributed::RequestSendHandler rpc_h(true); distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service( std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1)); new RPCSERVER_T(endpoint, 1));
......
...@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode"); int distributed_mode = Attr<int>("distributed_mode");
bool dc_sgd = Attr<bool>("dc_asgd"); bool dc_sgd = Attr<bool>("dc_asgd");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto pserver_id = Attr<int>("pserver_id"); auto pserver_id = Attr<int>("pserver_id");
...@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
int checkpoint_block_id = Attr<int>(kCheckpointBlockId); int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
int lr_decay_block_id = Attr<int>(kLRDecayBlockId); int lr_decay_block_id = Attr<int>(kLRDecayBlockId);
VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode VLOG(4) << "pserver_id: " << pserver_id
<< ", fan_in:" << fan_in << ", end_point:" << endpoint << ", distributed_mode:" << distributed_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id << ", checkpoint_block_id: " << checkpoint_block_id
<< ", lr_decay_block_id: " << lr_decay_block_id; << ", lr_decay_block_id: " << lr_decay_block_id;
...@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num"); auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");
request_send_handler_.reset( request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, dc_sgd)); new distributed::RequestSendHandler(distributed_mode, dc_sgd));
request_get_handler_.reset( request_get_handler_.reset(
new distributed::RequestGetHandler(sync_mode, dc_sgd)); new distributed::RequestGetHandler(distributed_mode, dc_sgd));
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode)); new distributed::RequestPrefetchHandler(distributed_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id)); distributed_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset( request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler()); new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset( request_notify_handler_.reset(new distributed::RequestNotifyHandler(
new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id)); distributed_mode, lr_decay_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), rpc_send_thread_num); request_send_handler_.get(), rpc_send_thread_num);
...@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit); signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit);
if (sync_mode) { if (distributed_mode == distributed::DistributedMode::kSync) {
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready..."; VLOG(3) << "wait server thread to become ready...";
...@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id); prefetch_block_id_list, checkpoint_block_id);
} else { } else {
if (distributed_mode == distributed::DistributedMode::kGeo) {
distributed::AsyncSparseParamUpdateRecorder::Init( distributed::AsyncSparseParamUpdateRecorder::Init(
fan_in, sparse_grad_name_to_param_name); fan_in, sparse_grad_name_to_param_name);
}
VLOG(2) << "RunAsyncLoop"; VLOG(2) << "RunAsyncLoop";
auto grad_to_block_id_str = auto grad_to_block_id_str =
...@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<int>("distributed_mode",
"indicate distriubte training mode, 0 is sync, 1 is "
"fully-async, 2 is half-async, 3 is geo")
.SetDefault(0);
AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.") AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<framework::BlockDesc *>>( AddAttr<std::vector<framework::BlockDesc *>>(
......
...@@ -32,9 +32,11 @@ USE_OP(sum); ...@@ -32,9 +32,11 @@ USE_OP(sum);
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
namespace d = paddle::operators::distributed
// global for simplicity. // global for simplicity.
std::unique_ptr<f::OperatorBase> listen_and_serv_op; std::unique_ptr<f::OperatorBase>
listen_and_serv_op;
int selected_port; int selected_port;
void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) { void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
...@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) { ...@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs.insert({"optimize_blocks", optimize_blocks}); attrs.insert({"optimize_blocks", optimize_blocks});
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})}); attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true}); attrs.insert({"distributed_mode", d::DistributedMode::kSync});
VLOG(4) << "before init op"; VLOG(4) << "before init op";
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
......
...@@ -72,7 +72,8 @@ void StartServer() { ...@@ -72,7 +72,8 @@ void StartServer() {
} }
TEST(SendNcclId, RPCServer) { TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new distributed::RequestSendHandler(true)); g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
......
...@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \ ...@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
default_startup_program, program_guard, Program, Variable default_startup_program, program_guard, Program, Variable
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..unique_name import generate as unique_name from ..unique_name import generate as unique_name
from ..transpiler.distribute_transpiler import DistributedMode
import logging import logging
__all__ = [ __all__ = [
...@@ -240,7 +241,8 @@ class ListenAndServ(object): ...@@ -240,7 +241,8 @@ class ListenAndServ(object):
'optimize_blocks': [ 'optimize_blocks': [
current_block current_block
], # did not support multiple optimize blocks in layers ], # did not support multiple optimize blocks in layers
'sync_mode': True, # did not support async now in layers 'distributed_mode':
DistributedMode.SYNC, # did not support async now in layers
'grad_to_block_id': [""] 'grad_to_block_id': [""]
}) })
......
...@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase): ...@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase):
self.origin_prog = main.clone() self.origin_prog = main.clone()
return main return main
def get_trainer(self, config=None): def get_trainer(self, config=None, sync_mode=True):
src = fluid.default_startup_program().clone() src = fluid.default_startup_program().clone()
t = self._transpiler_instance(config) t = self._transpiler_instance(config, sync_mode=True)
trainer_main = t.get_trainer_program(wait_port=False) trainer_main = t.get_trainer_program(wait_port=False)
trainer_startup = fluid.default_startup_program() trainer_startup = fluid.default_startup_program()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import gc
gc.set_debug(gc.DEBUG_COLLECTABLE)
class TranspilerTest(unittest.TestCase):
def setUp(self):
self.trainer_id = 0
self.trainers = 2
self.pservers = 2
# NOTE: we do not actually bind this port
self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"
self.pserver1_ep = "127.0.0.1:6174"
self.pserver2_ep = "127.0.0.1:6175"
self.sync_mode = True
self.transpiler = None
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost)
def get_main_program(self):
main = fluid.Program()
main.random_seed = 1
with fluid.program_guard(main):
self.net_conf()
self.origin_prog = main.clone()
return main
def get_trainer(self, config=None, sync_mode=True):
src = fluid.default_startup_program().clone()
t = self._transpiler_instance(config, sync_mode=True)
trainer_main = t.get_trainer_program(wait_port=False)
trainer_startup = fluid.default_startup_program()
assert (src.num_blocks == 1)
assert (trainer_startup.num_blocks == src.num_blocks)
return trainer_main, trainer_startup
def get_pserver(self, ep, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
pserver = t.get_pserver_program(ep)
startup = t.get_startup_program(ep, pserver)
return pserver, startup
def _transpiler_instance(self, config=None, sync_mode=True):
if not self.transpiler:
main = self.get_main_program()
self.transpiler = fluid.DistributeTranspiler(config=config)
self.transpiler.transpile(
self.trainer_id,
program=main,
pservers=self.pserver_eps,
trainers=self.trainers,
sync_mode=sync_mode)
return self.transpiler
def transpiler_test_impl(self):
pass
def test_transpiler(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.transpiler_test_impl()
# NOTE: run gc.collect to eliminate pybind side objects to
# prevent random double-deallocate when inherited in python.
del self.transpiler
del main
del startup
gc.collect()
class TestBasicModelAsync(TranspilerTest):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
config.runtime_split_send_recv = True
pserver, startup = self.get_pserver(self.pserver1_ep, config, False)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config, False)
trainer, _ = self.get_trainer(config, False)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'recv', 'recv'
])
self.assertEqual(len(pserver.blocks), 3)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
self.assertEqual(pserver.blocks[0].ops[0].attr("distributed_mode"), 1)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops], ["sgd"])
class TestBasicModelHalfAsync(TranspilerTest):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
config.runtime_split_send_recv = False
pserver, startup = self.get_pserver(self.pserver1_ep, config, False)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config, False)
trainer, _ = self.get_trainer(config, False)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send',
'recv', 'recv', 'concat'
])
self.assertEqual(len(pserver.blocks), 3)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
self.assertEqual(pserver.blocks[0].ops[0].attr("distributed_mode"), 2)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops], ["sgd"])
class TestBasicModelSync(TranspilerTest):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = True
config.runtime_split_send_recv = False
pserver, startup = self.get_pserver(self.pserver1_ep, config, True)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config, True)
trainer, _ = self.get_trainer(config, True)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send',
'send_barrier', 'recv', 'recv', 'fetch_barrier', 'concat'
])
self.assertEqual(len(pserver.blocks), 3)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
self.assertEqual(pserver.blocks[0].ops[0].attr("distributed_mode"), 0)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops],
["sum", "scale", "sgd"])
if __name__ == "__main__":
unittest.main()
...@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag ...@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag
flag2=test_list_and_serv_run_empty_optimize_block.flag flag2=test_list_and_serv_run_empty_optimize_block.flag
for i in {1..10}; do for i in {1..10}; do
sleep 3s sleep 6s
if [[ -f "${flag1}" && -f "${flag2}" ]]; then if [[ -f "${flag1}" && -f "${flag2}" ]]; then
echo "test_listen_and_serv_op exit" echo "test_listen_and_serv_op exit"
exit 0 exit 0
......
...@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id): ...@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.sync_mode = sync_mode config.sync_mode = sync_mode
t = fluid.DistributeTranspiler(config=config) t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers) t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
pserver_prog = t.get_pserver_program(current_endpoint) pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup) exe.run(pserver_startup)
...@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers, ...@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
config.slice_var_up = False config.slice_var_up = False
t = fluid.DistributeTranspiler(config=config) t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers) t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
pserver_prog = t.get_pserver_program(ps2) pserver_prog = t.get_pserver_program(ps2)
# pserver2 have no parameter # pserver2 have no parameter
......
...@@ -25,6 +25,7 @@ import paddle.fluid as fluid ...@@ -25,6 +25,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
from dist_test_utils import * from dist_test_utils import *
...@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode): ...@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks": [optimize_block], "optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0', "endpoint": '127.0.0.1:0',
"Fanin": 1, "Fanin": 1,
"sync_mode": True, "distributed_mode": DistributedMode.SYNC,
"grad_to_block_id": [] "grad_to_block_id": []
}) })
......
...@@ -26,6 +26,7 @@ import paddle.fluid.core as core ...@@ -26,6 +26,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
from dist_test_utils import * from dist_test_utils import *
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
def nce(input, weight, bias, sample_weight, labels, num_classes, def nce(input, weight, bias, sample_weight, labels, num_classes,
...@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode): ...@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks": [optimize_block], "optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0', "endpoint": '127.0.0.1:0',
"Fanin": 1, "Fanin": 1,
"sync_mode": True, "distributed_mode": DistributedMode.SYNC,
"grad_to_block_id": [] "grad_to_block_id": []
}) })
......
...@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator ...@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler.details import VarStruct, VarsDistributed from paddle.fluid.transpiler.details import VarStruct, VarsDistributed
from dist_test_utils import * from dist_test_utils import *
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
def run_pserver(pserver_id): def run_pserver(pserver_id):
...@@ -56,7 +57,7 @@ def run_pserver(pserver_id): ...@@ -56,7 +57,7 @@ def run_pserver(pserver_id):
"optimize_blocks": [optimize_block], "optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0', "endpoint": '127.0.0.1:0',
"Fanin": 1, "Fanin": 1,
"sync_mode": True, "distributed_mode": DistributedMode.SYNC,
"grad_to_block_id": [] "grad_to_block_id": []
}) })
......
...@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched ...@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
PRINT_LOG = False PRINT_LOG = False
class DistributedMode:
SYNC = 0
ASYNC = 1
HALF_ASYNC = 2
GEO = 3
def log(*args): def log(*args):
if PRINT_LOG: if PRINT_LOG:
print(args) print(args)
...@@ -313,6 +320,13 @@ class DistributeTranspiler(object): ...@@ -313,6 +320,13 @@ class DistributeTranspiler(object):
if self.config.split_method is None: if self.config.split_method is None:
self.config.split_method = RoundRobin self.config.split_method = RoundRobin
if self.config.sync_mode:
self.distributed_mode = DistributedMode.SYNC
elif self.config.runtime_split_send_recv:
self.distributed_mode = DistributedMode.ASYNC
else:
self.distributed_mode = DistributedMode.HALF_ASYNC
global PRINT_LOG global PRINT_LOG
if self.config.print_log: if self.config.print_log:
PRINT_LOG = True PRINT_LOG = True
...@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object): ...@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object):
"endpoint": endpoint, "endpoint": endpoint,
"pserver_id": self.pserver_endpoints.index(endpoint), "pserver_id": self.pserver_endpoints.index(endpoint),
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"sync_mode": self.sync_mode, "distributed_mode": self.distributed_mode,
"grad_to_block_id": grad_to_block_id, "grad_to_block_id": grad_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param, "sparse_grad_to_param": sparse_grad_to_param,
"lr_decay_block_id": lr_decay_block_id, "lr_decay_block_id": lr_decay_block_id,
......
...@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \ ...@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
from .details import wait_server_ready, VarsDistributed from .details import wait_server_ready, VarsDistributed
from .details import delete_ops from .details import delete_ops
from ..distribute_lookup_table import find_distributed_lookup_table from ..distribute_lookup_table import find_distributed_lookup_table
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig, DistributedMode
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
) )
...@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler): ...@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
"optimize_blocks": optimize_block, "optimize_blocks": optimize_block,
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"sync_mode": self.sync_mode, "distributed_mode": DistributedMode.GEO,
"grad_to_block_id": param_to_block_id, "grad_to_block_id": param_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param, "sparse_grad_to_param": sparse_grad_to_param,
"rpc_get_thread_num": self.server_config._rpc_get_thread_num, "rpc_get_thread_num": self.server_config._rpc_get_thread_num,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册