未验证 提交 8b50ad80 编写于 作者: T tangwei12 提交者: GitHub

checkpoint at distributed training (#14854)

checkpoint for distributed training.
上级 07dc5a15
......@@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
const std::string method = "SendRPC";
const std::string method = kSendRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
......@@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(100) << "ProcGetResponse";
VLOG(4) << "ProcGetResponse";
framework::Variable* outvar = nullptr;
// get response's trainer_id is not used
int trainer_id;
......@@ -127,39 +127,54 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name,
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", time_out);
}
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
}
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name,
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", time_out);
}
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& rpc_path,
int64_t time_out) {
VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
framework::AsyncIO(
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
......@@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = "PrefetchRPC";
const std::string method = kPrefetchRPC;
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
......@@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "BatchBarrierRPC";
const std::string method = kBatchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = "FetchBarrierRPC";
const std::string method = kFetchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC";
const std::string method = kSendMonomerFetchBarrierRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendCompleteRPC";
const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
const std::string method = "CheckPointNotifyRPC";
const std::string method = kCheckPointNotifyRPC;
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
......
......@@ -186,6 +186,13 @@ class GRPCClient : public RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
......@@ -228,11 +235,11 @@ class GRPCClient : public RPCClient {
void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, const std::string& rpc,
int64_t time_out);
VarHandlePtr _AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);
private:
grpc::CompletionQueue cq_;
......
......@@ -136,17 +136,65 @@ class RequestGet final : public RequestBase {
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << varname;
VLOG(4) << "RequestGet " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
class RequestGetNoBarrier final : public RequestBase {
public:
explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetNoBarrier() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
......@@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestSend(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
this);
......
......@@ -81,6 +81,7 @@ enum class GrpcMethod {
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
kGetVariableNoBarrier,
kGetMonomerVariable,
kGetMonomerBarrier,
};
......@@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kGetVariableNoBarrier:
return "/sendrecv.SendRecvService/GetVariableNoBarrier";
case GrpcMethod::kGetMonomerVariable:
return "/sendrecv.SendRecvService/GetMonomerVariable";
case GrpcMethod::kGetMonomerBarrier:
......
......@@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
......@@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname;
VLOG(4) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
......@@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
// get var from pserver immediately without barriers
string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, without_barrier_piece)) {
var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
VLOG(4) << "Get var " << var_name_piece << " with "
<< WITHOUT_BARRIER_MESSAGE;
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
......
......@@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler {
bool enable_dc_asgd_;
};
class RequestGetNoBarrierHandler final : public RequestHandler {
public:
RequestGetNoBarrierHandler() : RequestHandler(false) {}
virtual ~RequestGetNoBarrierHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
......
......@@ -43,6 +43,13 @@ class RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetMonomerVariable(
......
......@@ -17,8 +17,14 @@ package sendrecv;
option cc_generic_services = @cc_generic_services@;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
......@@ -27,12 +33,17 @@ service SendRecvService {
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
// It can be: LoDTensorSelectedRows or NCCL_ID
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// VariableMessage is serialized paddle variable message.
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
......@@ -49,14 +60,21 @@ message VariableMessage {
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
......
......@@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
......@@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
FLAGS_rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......
......@@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
......@@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_get_no_barrier_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
......
......@@ -27,34 +27,54 @@ namespace operators {
class RecvOp : public framework::OperatorBase {
public:
RecvOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto outs = Outputs("Out");
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
}
};
......@@ -79,12 +99,23 @@ This operator can get variables from server side.
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
.SetDefault(true);
AddAttr<std::vector<std::string>>(
"varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
}
};
class RecvOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
......
......@@ -1696,12 +1696,20 @@ class Program(object):
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = []
# for distribute
# for distribute training
# _is_distributed = True if under distributed training
self._is_distributed = False
# _is_chief = True if the trainer is the first one, usually No.0
self._is_chief = False
self._slice_vars_and_attrs = []
# _parameters_on_pservers records all the parameters distributed on parameter servers.
self._parameters_on_pservers = None
# _endpoints is a list about parameter servers ip:port, such as ["ip:port","ip:port"]
self._endpoints = []
# if current role is parameter server, the _ps_endpoint is its "ip:port"
self._ps_endpoint = None
# trainers_endpoints, it is used for distribution.
self._trainers_endpoints = []
# the distributed lookup table names
self._distributed_lookup_table = None
@property
......@@ -2232,8 +2240,9 @@ class Program(object):
"Program")
self._is_distributed = other._is_distributed
self._is_chief = other._is_chief
self._slice_vars_and_attrs = other._slice_vars_and_attrs
self._parameters_on_pservers = other._parameters_on_pservers
self._endpoints = other._endpoints
self._ps_endpoint = other._ps_endpoint
self._distributed_lookup_table = other._distributed_lookup_table
def _copy_data_info_from(self, other):
......
......@@ -19,6 +19,7 @@ import errno
import time
import shutil
import six
from functools import reduce
from paddle.fluid.executor import Executor
from paddle.fluid.evaluator import Evaluator
......@@ -183,8 +184,6 @@ def save_vars(executor,
# NOTE: don't save the variable which type is RAW
if each_var.type == core.VarDesc.VarType.RAW:
continue
if each_var.name == main_program._distributed_lookup_table:
continue
new_var = _clone_var_in_block_(save_block, each_var)
if filename is None:
save_block.append_op(
......@@ -206,16 +205,6 @@ def save_vars(executor,
outputs={},
attrs={'file_path': os.path.join(dirname, filename)})
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
attrs = {}
attrs['epmap'] = main_program._endpoints
attrs['dir'] = lookup_table_filename
attrs['lookup_table'] = main_program._distributed_lookup_table
save_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(save_program)
......@@ -267,6 +256,186 @@ def save_params(executor, dirname, main_program=None, filename=None):
filename=filename)
def _save_distributed_persistables(executor, dirname, main_program):
"""
save_persistables for distributed training.
the method will do things listed below:
1.save part of persistable variables on trainer.
2.receive "remote prefetch variables" from parameter servers and merge them.
3.save "distributed lookup table" on parameter servers.
4.receive "optimizer variables" from parameter servers and merge them.
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The saving directory path.
main_program(Program): The program whose parameters will be
saved. the main_program must be the trainer_program
get after transpiler.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
train_program = t.get_trainer_program()
_save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
"""
def __save_remote_params(executor, dirname, remote_params_map):
"""
recive params on pserver through rpc.
if the params are be sliced, will concat them to one, then save it.
"""
if not remote_params_map:
return
prog = Program()
block = prog.global_block()
# recv optimize vars from pserver
for name, remote_params in remote_params_map.items():
origin_var = None
is_slice = False
slice_vars = [0] * len(remote_params)
slice_var_names = [""] * len(remote_params)
endpoints = [""] * len(remote_params)
for idx, optimizer in enumerate(remote_params):
origin = optimizer.origin
slice = optimizer.slice
is_slice = optimizer.is_slice
block_id = optimizer.block_id
endpoint = optimizer.endpoint
if idx == 0:
origin_var = block.create_var(
name=origin.name,
type=origin.type,
shape=origin.shape,
dtype=origin.dtype,
persistable=True)
slice_var = block.create_var(
name="{}.slice.{}".format(slice.name, idx),
type=slice.type,
shape=slice.shape,
dtype=slice.dtype,
persistable=True)
index = block_id if is_slice else idx
slice_vars[index] = slice_var
slice_var_names[index] = slice.name
endpoints[index] = endpoint
if is_slice:
block.append_op(
type='recv',
inputs={"X": []},
outputs={"Out": slice_vars},
attrs={
"epmap": endpoints,
"with_barrier": False,
"varnames": slice_var_names,
"sync_mode": True
})
block.append_op(
type='concat',
inputs={'X': slice_vars},
outputs={'Out': origin_var},
attrs={})
else:
block.append_op(
type='recv',
inputs={"X": []},
outputs={"Out": [origin_var]},
attrs={
"epmap": endpoints[:1],
"with_barrier": False,
"varnames": slice_var_names,
"sync_mode": True
})
block.append_op(
type='save',
inputs={'X': [origin_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, origin_var.name)})
block.append_op(type='delete_var', inputs={'X': slice_vars})
executor.run(prog)
def __save_distributed_lookup_tables(executor, dirname,
distributed_lookup_table, endpoints):
"""
because the distributed lookup table may too huge to merge and save at one place,
it will be saved at parameter server independent respectively.
the save directory is dirname/"__lookup_table__".
"""
prog = Program()
block = prog.global_block()
# if there is lookup table, the trainer 0 will notify all pserver to save.
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
attrs = {}
attrs['epmap'] = endpoints
attrs['dir'] = lookup_table_filename
attrs['lookup_table'] = distributed_lookup_table
block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(prog)
def __exclude_vars(exclude_var_names=[]):
def is_valid(var):
if var.name in exclude_var_names:
return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
return is_valid
if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
if not main_program._is_distributed:
raise ValueError(
"'_save_distributed_persistables' just be designed for distributed training."
)
remote_params_map = main_program._parameters_on_pservers.get_distributed_vars_by_vtypes(
["Optimizer", "RemotePrefetch"], groupby=True)
exclude_var_names = []
if remote_params_map:
exclude_var_names.extend(remote_params_map.keys())
if main_program._distributed_lookup_table:
if isinstance(main_program._distributed_lookup_table, list):
exclude_var_names.extend(main_program._distributed_lookup_table)
else:
exclude_var_names.append(main_program._distributed_lookup_table)
local_vars = list(
filter(__exclude_vars(exclude_var_names), main_program.list_vars()))
save_vars(
executor, main_program=main_program, dirname=dirname, vars=local_vars)
if main_program._is_chief:
if remote_params_map:
__save_remote_params(executor, dirname, remote_params_map)
if main_program._distributed_lookup_table:
__save_distributed_lookup_tables(
executor, dirname, main_program._distributed_lookup_table,
main_program._endpoints)
def save_persistables(executor, dirname, main_program=None, filename=None):
"""
This function filters out all variables with `persistable==True` from the
......@@ -301,6 +470,12 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.save_persistables(executor=exe, dirname=param_path,
main_program=None)
"""
if main_program and main_program._is_distributed:
_save_distributed_persistables(
executor, dirname=dirname, main_program=main_program)
else:
save_vars(
executor,
dirname=dirname,
......@@ -402,17 +577,11 @@ def load_vars(executor,
if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
load_slice_vars = []
for each_var in main_program._slice_vars_and_attrs:
load_slice_vars.append(each_var[2].name)
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
if each_var.name in load_slice_vars:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if filename is None:
load_block.append_op(
......@@ -435,10 +604,6 @@ def load_vars(executor,
attrs={'file_path': os.path.join(dirname, filename)})
executor.run(load_prog)
# load slice vars on pserver, if have it.
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_attrs)
def load_params(executor, dirname, main_program=None, filename=None):
"""
......@@ -521,6 +686,11 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.load_persistables(executor=exe, dirname=param_path,
main_program=None)
"""
if main_program and main_program._is_distributed:
_load_distributed_persistables(
executor, dirname=dirname, main_program=main_program)
else:
load_vars(
executor,
dirname=dirname,
......@@ -529,6 +699,123 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
filename=filename)
def _load_distributed_persistables(executor, dirname, main_program=None):
"""
customized load_persistables for distributed training.
it should be used on parameter server,
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The load directory path.
main_program(Program): The program whose parameters will be
loaded. the main_program must be the pserver_program
get after transpiler.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
pserver_prog = t.get_pserver_program(...)
_load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
"""
def __is_distributed_part_var(varname):
trainer_idx = varname.find(".trainer_")
block_idx = varname.find(".block")
return trainer_idx or block_idx
def __load_persistable_vars(executor, dirname, need_load_vars):
load_prog = Program()
load_block = load_prog.global_block()
need_delete_vars = []
for param in need_load_vars:
origin_var = param.origin
slice_var = param.slice
is_slice = param.is_slice
offset = param.offset
if is_slice:
origin = load_block.create_var(
name="{}.load".format(origin_var.name),
type=origin_var.type,
shape=origin_var.shape,
dtype=origin_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [origin]},
attrs={
'file_path': os.path.join(dirname, origin_var.name)
})
slice = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
start = int(offset / dim1_flatten)
end = int(offset / dim1_flatten + slice.shape[0])
load_block.append_op(
type="slice",
inputs={'Input': origin},
outputs={'Out': slice},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
need_delete_vars.append(origin)
else:
origin = load_block.create_var(
name="{}".format(origin_var.name),
type=origin_var.type,
shape=origin_var.shape,
dtype=origin_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [origin]},
attrs={
'file_path': os.path.join(dirname, origin_var.name)
})
load_block.append_op(
type='delete_var',
inputs={'X': need_delete_vars}, )
executor.run(load_prog)
if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
if not main_program._is_distributed:
raise ValueError(
"'_load_distributed_persistables' just be designed for distributed training."
)
if not main_program._ps_endpoint:
raise ValueError(
"'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
)
need_load_vars = main_program._parameters_on_pservers.get_distributed_vars_by_ep(
main_program._ps_endpoint)
__load_persistable_vars(executor, dirname, need_load_vars)
def prepend_feed_ops(inference_program,
feed_target_names,
feed_holder_name='feed'):
......@@ -795,52 +1082,6 @@ def load_inference_model(dirname,
return [program, feed_target_names, fetch_targets]
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
pserver_endpoints):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program = Program()
pserver_notify_block = pserver_notify_program.global_block()
attrs = {}
attrs['epmap'] = pserver_endpoints
attrs['dir'] = dirname
attrs['lookup_table'] = lookup_table
pserver_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(pserver_notify_program)
def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap"
for op in program.global_block().ops:
......@@ -911,54 +1152,3 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)
def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
if not slice_vars_and_attrs:
return
load_prog = Program()
load_block = load_prog.global_block()
need_delete_vars = []
for var_tuple in slice_vars_and_attrs:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + slice_var.shape[0]
orig_var_name = orig_var.name
orig_var.name = "{}.origin".format(orig_var_name)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, orig_var_name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
need_delete_vars.append(clone_orig_var)
load_block.append_op(
type='delete_var',
inputs={'X': need_delete_vars}, )
executor.run(load_prog)
......@@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode)
args.trainers, args.sync_mode, False,
args.current_endpoint)
pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog)
......@@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
exe.run(startup_prog)
if need_load and model_dir:
self._load_persistable_vars(exe, model_dir, startup_prog)
fluid.io.load_persistables(exe, model_dir, pserver_prog)
exe.run(pserver_prog)
def run_trainer(self, args):
......@@ -158,7 +160,9 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
need_save = bool(int(os.getenv("SAVE", "0")))
model_dir = os.getenv("MODEL_DIR", "")
save_mode = os.getenv("SAVE_MODE", "")
if save_mode == "LOCAL":
if need_save:
for _ in six.moves.xrange(RUN_STEP):
loss, = exe.run(fetch_list=[avg_cost.name],
......@@ -166,12 +170,37 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
if need_save and model_dir:
io.save_persistables(startup_exe, model_dir, trainer_prog)
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
))
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
elif save_mode == "DIST":
skip_steps = int(os.getenv("SKIP_STEPS"))
loss = None
if need_save:
for idx in six.moves.xrange(8):
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
if need_save and model_dir and idx == skip_steps and args.trainer_id == 0:
io.save_persistables(startup_exe, model_dir,
trainer_prog)
else:
for idx in six.moves.xrange(8):
data = get_data()
if idx <= skip_steps:
continue
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(data))
if six.PY2:
print(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
else:
raise Exception("save_mode must be LOCAL or DIST")
if __name__ == "__main__":
paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train")
......
......@@ -75,8 +75,12 @@ def get_loss(cos_q_pt, cos_q_nt):
return avg_cost
def get_optimizer():
# SGD optimizer
def get_optimizer(op="sgd"):
if op.upper() == "sgd".upper():
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
elif op.upper() == "adam".upper():
optimizer = fluid.optimizer.Adam(learning_rate=base_lr)
else:
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
return optimizer
......@@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
inference_program = fluid.default_main_program().clone()
# Optimization
opt = get_optimizer()
opt = os.getenv('OPTIMIZER', 'sgd')
opt = get_optimizer(opt)
opt.minimize(avg_cost)
# Reader
......
......@@ -43,7 +43,8 @@ class TestDistRunnerBase(object):
pserver_endpoints,
trainers,
sync_mode,
dc_asgd=False):
dc_asgd=False,
current_endpoint=None):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd
......@@ -53,7 +54,8 @@ class TestDistRunnerBase(object):
program=main_program,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
sync_mode=sync_mode,
current_endpoint=current_endpoint)
return t
def run_pserver(self, args):
......
......@@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase):
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
......@@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1'
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'LOCAL',
}
self.check_with_place(
"dist_save_load.py",
delta=0,
check_error_log=False,
need_envs=need_envs)
class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
model_dir = tempfile.mkdtemp()
save_env = {}
save_env["SAVE_MODE"] = "DIST"
save_env["SAVE"] = "1"
save_env["MODEL_DIR"] = model_dir
save_env.update(required_envs)
tr0_var_1, tr1_var_1 = self._run_cluster(model_file, save_env,
check_error_log)
load_env = {}
load_env["LOAD"] = "1"
load_env["MODEL_DIR"] = model_dir
load_env.update(required_envs)
tr0_var_2, tr1_var_2 = self._run_cluster(model_file, load_env,
check_error_log)
shutil.rmtree(model_dir)
train0_1_np = np.array(tr0_var_1)
train1_1_np = np.array(tr1_var_1)
train0_2_np = np.array(tr0_var_2)
train1_2_np = np.array(tr1_var_2)
self.assertAlmostEqual(
train0_1_np.all(), train0_2_np.all(), delta=delta)
self.assertAlmostEqual(
train1_1_np.all(), train1_2_np.all(), delta=delta)
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'DIST',
'OPTIMIZER': 'ADAM',
'SKIP_STEPS': str(np.random.randint(2, 6))
}
self.check_with_place(
"dist_save_load.py",
......
......@@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest):
pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0])
total_numel = six.moves.reduce(
lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual(
total_numel,
six.moves.reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) +
six.moves.reduce(lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape))
vars_ps1 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.pserver1_ep)
vars_ps2 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.pserver2_ep)
self.assertTrue(vars_ps1)
self.assertTrue(vars_ps2)
for idx in six.moves.xrange(len(vars_ps1)):
total_numel = 0
ps1_numel, ps2_numel = 0, 0
ps1_var = vars_ps1[idx]
if not ps1_var.is_slice:
total_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].slice.shape)
else:
ps2_var = None
for var in vars_ps2:
if var.origin.name == ps1_var.origin.name:
ps2_var = var
break
total_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.slice.shape)
ps2_numel = six.moves.reduce(lambda x, y: x * y,
ps2_var.slice.shape)
self.assertEqual(total_numel, ps1_numel + ps2_numel)
class TestNCCL2Transpile(TranspilerTest):
......
......@@ -39,7 +39,7 @@ from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
default_startup_program, Block, \
Parameter, grad_var_name
Parameter, Variable, grad_var_name
from .details import *
from ..distribute_lookup_table import find_distributed_lookup_table
from functools import reduce
......@@ -62,6 +62,260 @@ def log(*args):
print(args)
class VarStruct(object):
"""
record part properties of a Variable in python.
"""
def __init__(self, name, shape, dtype, type, lod_level, persistable):
self.name = name
self.shape = shape
self.dtype = dtype
self.type = type
self.lod_level = lod_level
self.persistable = persistable
class VarDistributed(object):
"""
a class to record the var distributed on parameter servers.
the class will record the relationship between origin var and slice var.
the slice var's properties, such as type/shape/offset/endpoint.
"""
def __init__(self,
origin_var,
slice_var,
is_slice=None,
block_id=None,
offset=None,
vtype=None,
endpoint=None):
"""
Args:
origin_var(Variable|VarStruct): origin var properties
slice_var(Variable|VarStruct): slice var properties
is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
block_id(int|None): the number about the slice var.
offset(int|None): if the slice var is sliced, offset is the numel before the var.
vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch.
endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001"
"""
if isinstance(origin_var, Variable):
self.origin = self.__create_var_struct(origin_var)
else:
self.origin = origin_var
if isinstance(slice_var, Variable):
self.slice = self.__create_var_struct(slice_var)
else:
self.slice = slice_var
if self.equal(self.origin, self.slice):
self.is_slice = False
self.block_id = 0
self.offset = 0
else:
self.is_slice = True
self.block_id = 0
self.offset = 0
if is_slice is not None:
self.is_slice = is_slice
if block_id is not None:
self.block_id = block_id
if offset is not None:
self.offset = offset
self.vtype = vtype
self.endpoint = endpoint
@staticmethod
def __create_var_struct(var):
return VarStruct(var.name, var.shape, var.dtype, var.type,
var.lod_level, var.persistable)
@staticmethod
def equal(var1, var2):
"""
the two var is equal or not.
Returns:
bool: equal will return True else False
"""
assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct)
return var1.name == var2.name and \
var1.type == var2.type and \
var1.shape == var2.shape and \
var1.dtype == var2.dtype and \
var1.lod_level == var2.lod_level and \
var1.persistable == var2.persistable
def __str__(self):
origin_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})". \
format(i="{", e="}", name=self.origin.name, type=self.origin.type,
shape=self.origin.shape, dtype=self.origin.dtype)
slice_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})" \
".slice({is_slice}).block({block_id}).offset({offset})". \
format(i="{", e="}", name=self.slice.name, type=self.slice.type,
shape=self.slice.shape, dtype=self.slice.dtype,
is_slice=self.is_slice, block_id=self.block_id, offset=self.offset)
return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format(
self.vtype, origin_var_str, slice_var_str, self.endpoint)
class VarsDistributed(object):
"""
a gather about VarDistributed with many methods to find distributed vars.
through the class, we can get overview about the distributed parameters on parameter servers.
this class may centralized and convenient for developer to manage and get variable's distribute.
other module can also use this to find variables such io.py.
"""
def __init__(self):
self.distributed_vars = []
def add_distributed_var(self,
origin_var,
slice_var,
is_slice=None,
block_id=None,
offset=None,
vtype=None,
endpoint=None):
"""
add distributed var in this.
Args:
origin_var(Variable|VarStruct): origin var properties
slice_var(Variable|VarStruct): slice var properties
is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
block_id(int|None): the number about the slice var.
offset(int|None): if the slice var is sliced, offset is the numel before the var.
vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch.
endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001"
Returns:
None
"""
self.distributed_vars.append(
VarDistributed(origin_var, slice_var, is_slice, block_id, offset,
vtype, endpoint))
def get_distributed_var_by_slice(self, var_name):
"""
get distributed var by conditions.
Args:
var_name(str): slice var name, such as "w.traier0.block1"
Returns:
VarDistributed: distributed var.
"""
for dist_var in self.distributed_vars:
if dist_var.slice.name == var_name:
return dist_var
return None
@staticmethod
def equal(var1, var2):
"""
the two var is equal or not.
Returns:
bool: equal will return True else False
"""
return var1.name == var2.name and \
var1.type == var2.type and \
var1.shape == var2.shape and \
var1.dtype == var2.dtype and \
var1.lod_level == var2.lod_level and \
var1.persistable == var2.persistable
def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint):
"""
get distributed var by conditions.
Args:
origin_var_name(str):
endpoint(str): the parameter endpoint, such as "127.0.0.1:1001"
Returns:
VarDistributed: distributed var.
"""
for dist_var in self.distributed_vars:
if dist_var.origin.name == origin_var_name and dist_var.endpoint == endpoint:
return dist_var
return None
def get_distributed_vars_by_vtypes(self, vtypes, groupby=False):
"""
get distributed vars by conditions.
Args:
vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch"
groupby(bool|False): group by origin var or not.
Returns:
list: distributed var list.
dict: distributed var map when groupby=True
"""
vtype_vars = []
for var in self.distributed_vars:
if var.vtype in vtypes:
vtype_vars.append(var)
if not groupby:
return vtype_vars
params_map = {}
for var in vtype_vars:
origin_var_name = var.origin.name
if origin_var_name in params_map.keys():
optimizers = params_map.get(origin_var_name)
else:
optimizers = []
optimizers.append(var)
params_map[origin_var_name] = optimizers
return params_map
def get_distributed_vars_by_ep(self, endpoint, vtype=None):
"""
get distributed vars by conditions.
Args:
endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001"
vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch"
Returns:
list: distributed var list.
"""
endpoint_vars = []
for var in self.distributed_vars:
if var.endpoint == endpoint:
endpoint_vars.append(var)
if not vtype:
return endpoint_vars
vtype_vars = []
for var in endpoint_vars:
if var.vtype == vtype:
vtype_vars.append(var)
return vtype_vars
def overview(self):
"""
get the overview string about all params on all parameter servers.
Returns:
Str: overview string.
"""
vars_str = []
for var in self.distributed_vars:
vars_str.append(str(var))
return "\n".join(vars_str)
class VarBlock:
def __init__(self, varname, offset, size):
self.varname = varname
......@@ -223,16 +477,13 @@ class DistributeTranspiler(object):
trainer_id,
trainers,
current_endpoint,
startup_program=None,
wait_port=True):
startup_program=None):
if not startup_program:
startup_program = default_startup_program()
if trainer_id >= 0:
worker_endpoints = trainers.split(",")
# send NCCL_ID to others or recv from trainer 0
worker_endpoints.remove(current_endpoint)
if trainer_id == 0 and wait_port:
wait_server_ready(worker_endpoints)
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
......@@ -313,13 +564,11 @@ class DistributeTranspiler(object):
if self.config.mode == "nccl2":
assert (isinstance(trainers, str))
self.origin_program._trainers_endpoints = trainers.split(",")
self._transpile_nccl2(
trainer_id,
trainers,
current_endpoint,
startup_program=startup_program,
wait_port=self.config.wait_port)
startup_program=startup_program)
return
self.trainer_num = trainers
......@@ -327,6 +576,7 @@ class DistributeTranspiler(object):
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.vars_overview = VarsDistributed()
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = self.config.split_method(self.pserver_endpoints)
......@@ -347,6 +597,7 @@ class DistributeTranspiler(object):
# add distributed attrs to program
self.origin_program._is_distributed = True
self.origin_program._endpoints = self.pserver_endpoints
self.origin_program._ps_endpoint = current_endpoint
self.origin_program._is_chief = self.trainer_id == 0
self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None
......@@ -454,6 +705,10 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
distributed_var = self.vars_overview.get_distributed_var_by_slice(
recv_vars[i].name)
distributed_var.endpoint = ep
# step4: Concat the parameters splits together after recv.
all_recv_outputs = []
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
......@@ -480,6 +735,12 @@ class DistributeTranspiler(object):
recv_op_role_var_name = splited_trainer_grad[0].name
if param_varname in self.sparse_param_to_height_sections:
for table_name in table_names:
distributed_var = self.vars_overview.get_distributed_var_by_slice(
table_name)
distributed_var.vtype = "RemotePrefetch"
height_sections = self.sparse_param_to_height_sections[
param_varname]
self._update_remote_sparse_update_op(
......@@ -532,6 +793,9 @@ class DistributeTranspiler(object):
pserver_endpoints)
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
self._get_distributed_optimizer_vars()
self.origin_program._parameters_on_pservers = self.vars_overview
def get_trainer_program(self, wait_port=True):
"""
Get transpiled trainer side program.
......@@ -541,6 +805,7 @@ class DistributeTranspiler(object):
"""
# remove optimize ops and add a send op to main_program
# FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay?
lr_ops = self._get_lr_ops()
delete_ops(self.origin_program.global_block(), self.optimize_ops)
delete_ops(self.origin_program.global_block(), lr_ops)
......@@ -665,9 +930,14 @@ class DistributeTranspiler(object):
# NOTE: assume blocks of the same variable is not distributed
# on the same pserver, only change param/grad varnames for
# trainers to fetch.
sys.stderr.write(
"get_pserver_program() is deprecated, call get_pserver_programs() to get pserver main and startup in a single call.\n"
)
# step1
pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
pserver_program._copy_dist_param_info_from(self.origin_program)
# step2: Create vars to receive vars at parameter servers.
recv_inputs = []
for v in self.param_grad_ep_mapping[endpoint]["params"]:
......@@ -703,9 +973,6 @@ class DistributeTranspiler(object):
else:
recv_inputs.append(single_trainer_var)
self._slice_params_and_optimizes = self._get_slice_vars_and_attrs(
endpoint)
# step 3
# Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops
......@@ -882,10 +1149,6 @@ class DistributeTranspiler(object):
outputs={},
attrs=attrs)
# add distributed attrs
pserver_program._slice_vars_and_attrs = list(
self._slice_params_and_optimizes.values())
pserver_program._sync_with_cpp()
# save pserver program to generate pserver side startup relatively.
self.pserver_program = pserver_program
......@@ -984,30 +1247,88 @@ class DistributeTranspiler(object):
inputs={"X": startup_param_var},
outputs={"Out": startup_tmpvar})
# add slice vars
s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs
return s_prog
def _get_slice_vars_and_attrs(self, endpoint):
slice_vars_and_attrs = {}
# ====================== private transpiler functions =====================
def _get_slice_var_info(self, slice_var):
block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
orig_var_name, block_name, _ = self._get_varname_parts(param.name)
block_idx = 0
offset = 0
is_slice = False
orig_var_name, block_name, _ = self._get_varname_parts(slice_var.name)
if not block_name:
continue
return is_slice, block_idx, offset
block_idx = int(block_name.split(block_suffix)[1])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_dim0 = 0
slice_vars = self.param_var_mapping[orig_var_name]
orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:])
for slice_var in slice_vars[:block_idx]:
skip_dim0 += slice_var.shape[0]
slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param]
return slice_vars_and_attrs
# ====================== private transpiler functions =====================
offset = skip_dim0 * orig_dim1_flatten
is_slice = True
return is_slice, block_idx, offset
def _get_distributed_optimizer_vars(self):
def _get_distributed_optimizer_var(endpoint):
opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops):
if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
endpoint, op):
opt_op_on_pserver.append(op)
for opt_op in opt_op_on_pserver:
dist_var = None
for key in opt_op.input_names:
if key == "Param":
param_name = opt_op.input(key)[0]
dist_var = self.vars_overview.get_distributed_var_by_origin_and_ep(
param_name, endpoint)
break
for key in opt_op.input_names:
if key in ["Param", "Grad", "LearningRate"]:
continue
origin_var = self.origin_program.global_block().vars[
opt_op.input(key)[0]]
# update accumulator variable shape
new_shape = self._get_optimizer_input_shape(
opt_op.type, key, origin_var.shape,
dist_var.slice.shape)
if new_shape == dist_var.slice.shape:
splited_var = VarStruct(
name=origin_var.name,
shape=new_shape,
dtype=origin_var.dtype,
type=origin_var.type,
lod_level=origin_var.lod_level,
persistable=origin_var.persistable)
self.vars_overview.add_distributed_var(
origin_var=origin_var,
slice_var=splited_var,
is_slice=dist_var.is_slice,
block_id=dist_var.block_id,
offset=dist_var.offset,
vtype="Optimizer",
endpoint=endpoint)
else:
self.vars_overview.add_distributed_var(
origin_var=origin_var,
slice_var=origin_var,
is_slice=False,
block_id=0,
offset=0,
vtype="Optimizer",
endpoint=endpoint)
for ep in self.pserver_endpoints:
_get_distributed_optimizer_var(ep)
def _update_dist_lookup_table_vars(self, param_list, grad_list,
params_grads):
......@@ -1093,6 +1414,22 @@ class DistributeTranspiler(object):
# origin_param_name -> [splited_param_vars]
self.param_var_mapping = self._create_vars_from_blocklist(
self.origin_program, param_blocks)
for orig_name, splited_vars in self.param_var_mapping.items():
orig_var = self.origin_program.global_block().var(orig_name)
for splited_var in splited_vars:
is_slice, block_id, offset = self._get_slice_var_info(
splited_var)
self.vars_overview.add_distributed_var(
origin_var=orig_var,
slice_var=splited_var,
block_id=block_id,
offset=offset,
is_slice=is_slice,
vtype="Param")
# origin_grad_name -> [splited_grad_vars]
self.grad_var_mapping = self._create_vars_from_blocklist(
self.origin_program,
......@@ -1729,13 +2066,6 @@ class DistributeTranspiler(object):
shape=new_shape)
new_inputs[key] = tmpvar
# var shape been changed
if new_shape != var.shape:
slice_var_args = self._slice_params_and_optimizes[
param_var.name]
self._slice_params_and_optimizes[
var.name] = [var, slice_var_args[1], tmpvar]
# change output's ParamOut variable
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
......@@ -1763,7 +2093,7 @@ class DistributeTranspiler(object):
# skip per trainer vars
if g.name.find(".trainer_") == -1:
# only param or grads have splited blocks
if self._orig_varname(g.name) in self.grad_name_to_param_name or\
if self._orig_varname(g.name) in self.grad_name_to_param_name or \
self._orig_varname(g.name) in self.param_name_to_grad_name:
grad_block = g
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册