Unverified · Commit 306236c2 authored by Wu Yi, committed by GitHub

feature/DC asgd (#12722)

* wip

* add ref_by_trainer_id op

* ready to test

* fix ref inputs

* refine rpc_op_handle

* fix merge bug
Parent: c3cbf0b8
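This commit wires Delay-Compensated ASGD (DC-ASGD) into the parameter-server path: every send/get/prefetch RPC now carries a trainer_id, the pserver keeps a per-trainer backup copy of each parameter (snapshotted when that trainer pulls it), and the new ref_by_trainer_id op selects the matching backup when the compensated gradient is built. Assuming the standard DC-ASGD formulation (Zheng et al., "Asynchronous Stochastic Gradient Descent with Delay Compensation"), the pserver replaces a stale gradient g from a trainer with

    g_dc = g + lambda * g * g * (theta_now - theta_bak)    (all products elementwise)

where theta_now is the current parameter value on the pserver, theta_bak is the backup taken when that trainer last pulled the parameter, and lambda is the compensation scale (still a TODO in this patch, so it is effectively 1).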
......@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_(place) {}
void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if (ir::IsControlDepVar(*in->Node())) { // HACK
if (ir::IsControlDepVar(*in->Node())) {
continue;
}
if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
op_->Run(*tmp_scope, place_);
this->RunAndRecordEvent([this] {
op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
place_);
});
}
std::string RPCOpHandle::Name() const { return name_; }
......
......@@ -85,8 +85,10 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
::paddle::operators::distributed::GRPCClient>(0)
->SendComplete();
#endif
}
......
......@@ -38,9 +38,10 @@ class CheckpointNotifyOp : public framework::OperatorBase {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
int trainer_id = Attr<int>("trainer_id");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
......@@ -63,6 +64,7 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
"dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddComment(R"DOC(
CheckpointNotify operator
......
......@@ -79,7 +79,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
......@@ -105,7 +105,10 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
// the trainer_id in the Get response is not used
int trainer_id;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
&trainer_id);
}
template <typename T>
......@@ -135,6 +138,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
......
......@@ -34,8 +34,8 @@ namespace distributed {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
::grpc::ByteBuffer* msg, const std::string& out_name,
const int trainer_id) {
platform::RecordRPCEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
......@@ -45,6 +45,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
size_t payload_size;
request.set_varname(name);
request.set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
......@@ -147,11 +148,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
framework::Variable** var, int* trainer_id) {
platform::RecordRPCEvent record_event("deserial", &ctx);
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
......
......@@ -38,12 +38,13 @@ typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());
const std::string& out_varname = std::string(),
const int trainer_id = 0);
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var);
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
......
......@@ -102,9 +102,10 @@ class RequestSend final : public RequestBase {
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar);
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_);
}
......@@ -133,13 +134,14 @@ class RequestGet final : public RequestBase {
void Process() override {
// proc request.
std::string varname = request_.varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << varname;
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar);
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
......@@ -179,6 +181,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
......@@ -187,7 +190,8 @@ class RequestPrefetch final : public RequestBase {
// out var must be created in local scope!
framework::Variable* outvar = scope->Var(out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
......@@ -225,12 +229,13 @@ class RequestCheckpointNotify final : public RequestBase {
std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir);
trainer_id, checkpoint_dir);
Finish(reply_, &responder_);
}
......
......@@ -293,6 +293,14 @@ int GRPCVariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kTrainerIdFieldNumber: {
uint64_t trainer_id = 0;
if (!input.ReadVarint64(&trainer_id)) {
return tag;
}
meta_.set_trainer_id(trainer_id);
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
......
......@@ -190,6 +190,7 @@ class RequestHandler {
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") = 0;
protected:
......
......@@ -36,6 +36,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname;
......@@ -76,6 +77,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
VLOG(4) << "RequestGetHandler:" << varname;
if (sync_mode_) {
......@@ -88,6 +90,19 @@ bool RequestGetHandler::Handle(const std::string& varname,
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) {
// NOTE: the format is determined by distributed_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname);
auto t_orig = var->Get<framework::LoDTensor>();
auto param_bak = scope_->Var(param_bak_name);
auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
*outvar = scope_->FindVar(varname);
}
}
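The DC-ASGD branch above means that whenever a trainer pulls a parameter, the Get handler first snapshots the current value into a per-trainer backup variable whose name must match the "<param>.trainer_<id>_bak" format created by the transpiler. A minimal Python sketch of that behavior, using a dict of numpy arrays as a stand-in for the pserver scope (names are illustrative, not part of this patch):

import numpy as np

def handle_get(varname, trainer_id, pserver_scope):
    # same naming format the transpiler uses on the pserver side
    param_bak_name = "%s.trainer_%d_bak" % (varname, trainer_id)
    # snapshot the current parameter before handing it back to the trainer
    pserver_scope[param_bak_name] = pserver_scope[varname].copy()
    return pserver_scope[varname]

pserver_scope = {"fc_0.w_0": np.ones(4)}  # toy scope
handle_get("fc_0.w_0", trainer_id=3, pserver_scope=pserver_scope)
assert "fc_0.w_0.trainer_3_bak" in pserver_scope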
......@@ -98,6 +113,7 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
......@@ -113,6 +129,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
......
......@@ -36,20 +36,34 @@ namespace distributed {
class RequestSendHandler final : public RequestHandler {
public:
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
explicit RequestSendHandler(bool sync_mode, bool enable_dc_asgd = false)
: RequestHandler(sync_mode) {
enable_dc_asgd_ = enable_dc_asgd;
}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
private:
bool enable_dc_asgd_;
};
class RequestGetHandler final : public RequestHandler {
public:
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
explicit RequestGetHandler(bool sync_mode, bool enable_dc_asgd = false)
: RequestHandler(sync_mode) {
enable_dc_asgd_ = enable_dc_asgd;
}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
private:
bool enable_dc_asgd_;
};
class RequestPrefetchHandler final : public RequestHandler {
......@@ -58,6 +72,7 @@ class RequestPrefetchHandler final : public RequestHandler {
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
};
......@@ -70,6 +85,7 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
private:
......
......@@ -24,6 +24,7 @@ namespace distributed {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
int RPCClient::trainer_id_ = 0;
} // namespace distributed
} // namespace operators
......
......@@ -72,14 +72,15 @@ class RPCClient {
virtual bool Wait() = 0;
template <typename T>
static RPCClient* GetInstance() {
std::call_once(init_flag_, &RPCClient::Init<T>);
static RPCClient* GetInstance(int trainer_id) {
std::call_once(init_flag_, &RPCClient::Init<T>, trainer_id);
return rpc_client_.get();
}
// Init is called by GetInstance.
template <typename T>
static void Init() {
static void Init(int trainer_id) {
trainer_id_ = trainer_id;
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new T());
rpc_client_->InitImpl();
......@@ -88,6 +89,8 @@ class RPCClient {
protected:
virtual void InitImpl() {}
// each trainer has exactly one trainer id, so it can be static
static int trainer_id_;
private:
static std::once_flag init_flag_;
......
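Because Init is guarded by std::call_once, the trainer_id passed to the first GetInstance call is the one recorded for the whole process; later calls return the same client and their trainer_id argument is ignored. A Python analogue of that behavior (illustrative only, not part of the patch):

_client = None
_trainer_id = 0

def get_instance(trainer_id):
    global _client, _trainer_id
    if _client is None:            # plays the role of std::call_once + Init
        _trainer_id = trainer_id   # the first caller's id wins
        _client = object()         # stands in for new T() + InitImpl()
    return _client                 # later trainer_id arguments are ignored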
......@@ -125,7 +125,7 @@ TEST(PREFETCH, CPU) {
g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
......@@ -165,7 +165,7 @@ TEST(COMPLETE, CPU) {
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady();
......
......@@ -79,6 +79,7 @@ message VariableMessage {
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from 1 to 2.
int64 profile = 11;
int64 trainer_id = 12;
}
message VoidMessage {}
......@@ -92,6 +92,8 @@ class VariableResponse {
return scope_->FindVar(meta_.varname());
}
int GetTrainerId() { return static_cast<int>(meta_.trainer_id()); }
protected:
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
......
......@@ -37,7 +37,8 @@ class FetchBarrierOp : public framework::OperatorBase {
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
......@@ -61,6 +62,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would know all variables have been sent.
)DOC");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
......
......@@ -61,7 +61,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list");
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
......
......@@ -218,23 +218,26 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
VLOG(2) << "RunAsyncLoop";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
for (const auto &grad_and_id : grad_to_block_id_str) {
DoubleFindMap<std::string, int32_t> grad_to_block_id;
auto append_block_maps = [](DoubleFindMap<std::string, int32_t> *out_map,
const std::string &grad_and_id) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
(*out_map)[pieces[0]] = block_id;
};
for (const auto &grad_and_id : grad_to_block_id_str) {
append_block_maps(&grad_to_block_id, grad_and_id);
}
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
......@@ -244,15 +247,22 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
// execute global block if needed
if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
// execute the global block if needed; block id 1 in the program is the global
// block when it is not bound to a grad var for its update.
if (block_list[0] == 1 &&
grad_to_block_id.find_value(static_cast<int32_t>(1)) ==
grad_to_block_id.end()) {
executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
}
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
grad_to_prepared_ctx, param_to_prepared_ctx;
for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
auto blkid = block_list[i];
auto it = grad_to_block_id.find_value(blkid);
if (it != grad_to_block_id.end()) {
grad_to_prepared_ctx[it->first] = optimize_prepared[i];
}
}
request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
......@@ -315,6 +325,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
bool dc_sgd = Attr<bool>("dc_asgd");
auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
......@@ -328,8 +339,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode));
request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, dc_sgd));
request_get_handler_.reset(
new distributed::RequestGetHandler(sync_mode, dc_sgd));
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
......@@ -443,6 +456,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
.SetDefault(false);
AddAttr<std::vector<framework::BlockDesc *>>(
kOptimizeBlocks, "Optimize blocks to run on server side.")
.SetDefault({});
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <atomic>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
......@@ -37,6 +38,17 @@ constexpr char kCheckpointBlockId[] = "checkpint_block_id";
void RunServer(std::shared_ptr<distributed::RPCServer> service);
template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
public:
typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
return std::find_if(this->begin(), this->end(),
[&v](const std::pair<const std::string, int> p) {
return p.second == v;
});
}
};
class ListenAndServOp : public framework::OperatorBase {
public:
ListenAndServOp(const std::string& type,
......
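DoubleFindMap simply adds a reverse lookup (value back to entry) on top of std::unordered_map; RunAsyncLoop uses it to check whether optimize block 1 is bound to any grad var before running it as the global block. A rough Python equivalent of find_value (illustrative, not part of the patch):

def find_value(d, v):
    # first (key, value) pair whose value equals v, else None
    return next((item for item in d.items() if item[1] == v), None)

grad_to_block_id = {"w@GRAD": 1, "b@GRAD": 2}
assert find_value(grad_to_block_id, 2) == ("b@GRAD", 2)
assert find_value(grad_to_block_id, 3) is None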
......@@ -42,7 +42,8 @@ class PrefetchOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
......@@ -69,6 +70,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) result "
"to be fetched from parameter server")
.AsDuplicable();
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
......
......@@ -42,7 +42,8 @@ class RecvOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
......@@ -73,6 +74,7 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
#include <string>
namespace paddle {
namespace operators {
class RefByTrainerIdOp : public framework::OperatorWithKernel {
public:
RefByTrainerIdOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"),
"Input(X) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("TrainerId"),
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1,
"TrainerId should be a scalar.");
// Out's shape is determined at runtime.
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X")[0]->type()),
ctx.GetPlace());
}
};
class RefByTrainerIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor list.").AsDuplicable();
AddInput("TrainerId", "(Tensor) Scalar int, the trainer id runtime value.");
AddOutput("Out", "(Tensor) Return one tensor reference of X[trainer_id]");
AddComment(R"DOC(
**RefByTrainerId operator**
Return a reference of a tensor, using trainer_id as the index to find from the input.
$$Out = X[TrainerId]$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ref_by_trainer_id, ops::RefByTrainerIdOp,
ops::RefByTrainerIdOpMaker);
REGISTER_OP_CPU_KERNEL(
ref_by_trainer_id,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, double>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
REGISTER_OP_CUDA_KERNEL(
ref_by_trainer_id,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RefByTrainerIdKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out = context.Output<framework::Tensor>("Out");
auto in_list = context.MultiInput<framework::Tensor>("X");
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
int64_t trainer_id;
auto* trainer_id_data = trainer_id_t->data<int64_t>();
if (platform::is_gpu_place(context.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
auto stream = context.cuda_device_context().stream();
memory::Copy<>(platform::CPUPlace(), &trainer_id,
boost::get<platform::CUDAPlace>(context.GetPlace()),
trainer_id_data, sizeof(int64_t), stream);
#endif
} else {
trainer_id = *trainer_id_data;
}
printf("after get trainer_id %lu\n", trainer_id);
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id]));
}
};
} // namespace operators
} // namespace paddle
......@@ -39,7 +39,8 @@ class SendBarrierOp : public framework::OperatorBase {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
VLOG(3) << "SendBarrierOp sync";
......@@ -67,6 +68,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would know all variables have been sent.
)DOC");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
......
......@@ -44,7 +44,8 @@ class SendOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
......@@ -79,6 +80,7 @@ This operator will send variables to listen_and_serve op at the parameter server
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
......
......@@ -92,7 +92,7 @@ TEST(SendNcclId, RPCServer) {
std::string ep = string::Sprintf("127.0.0.1:%d", port);
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
......
......@@ -37,10 +37,15 @@ class TestDistRunnerBase(object):
"get_model should be implemented by child classes.")
@staticmethod
def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers,
sync_mode):
def get_transpiler(trainer_id,
main_program,
pserver_endpoints,
trainers,
sync_mode,
dc_asgd=False):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id=trainer_id,
......@@ -55,7 +60,7 @@ class TestDistRunnerBase(object):
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode)
args.trainers, args.sync_mode, args.dc_asgd)
pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog)
......@@ -75,8 +80,7 @@ class TestDistRunnerBase(object):
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(),
args.endpoints, args.trainers,
args.sync_mode)
args.sync_mode, args.dc_asgd)
trainer_prog = t.get_trainer_program()
else:
trainer_prog = fluid.default_main_program()
......@@ -155,6 +159,7 @@ def runtime_main(test_class):
parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument(
'--use_reader_alloc', action='store_true', required=False)
parser.add_argument('--batch_size', required=False, type=int, default=2)
......@@ -200,6 +205,7 @@ class TestDistBase(unittest.TestCase):
self._enforce_place = None
self._mem_opt = False
self._use_reduce = False
self._dc_asgd = False # must use with async mode
self._use_reader_alloc = True
self._setup_config()
self._after_setup_config()
......
......@@ -53,6 +53,15 @@ class TestDistMnistAsync(TestDistBase):
self.check_with_place("dist_mnist.py", delta=200)
class TestDistMnistDcAsgd(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._dc_asgd = True
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=200)
# FIXME(typhoonzero): enable these tests once we have 4
# 4 GPUs on CI machine, and the base class should be updated.
#
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestRefByTrainerIdOp(OpTest):
def setUp(self):
self.op_type = "ref_by_trainer_id"
param_baks = [("x%d" % x, np.random.random((10, 10)).astype("float32"))
for x in range(10)]
self.inputs = {
'X': param_baks,
'TrainerId': np.array([8]).astype("int64")
}
self.outputs = {'Out': param_baks[8][1]}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
......@@ -38,7 +38,7 @@ import six
import logging
from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
default_startup_program, Block, \
Parameter, grad_var_name
......@@ -138,6 +138,7 @@ class DistributeTranspilerConfig(object):
slice_var_up = True
split_method = None
min_block_size = 8192
enable_dc_asgd = False
# supported modes: pserver, nccl2
mode = "pserver"
print_log = False
......@@ -252,6 +253,8 @@ class DistributeTranspiler(object):
n workers, the id may range from 0 ~ n-1
program (Program|None): program to transpile,
default is fluid.default_main_program().
startup_program (Program|None): startup_program to transpile,
default is fluid.default_startup_program().
pservers (str): comma separated ip:port string for the pserver
list.
trainers (int|str): in pserver mode this is the number of
......@@ -383,6 +386,8 @@ class DistributeTranspiler(object):
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......@@ -426,6 +431,7 @@ class DistributeTranspiler(object):
outputs={"Out": splited_var},
attrs={
"epmap": eps,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
......@@ -440,6 +446,7 @@ class DistributeTranspiler(object):
outputs={"Out": all_recv_outputs},
attrs={
"endpoints": pserver_endpoints,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......@@ -651,6 +658,24 @@ in a single call.")
endpoint, op):
opt_op_on_pserver.append(op)
# step 3.3
# prepare if dc asgd is enabled
if self.config.enable_dc_asgd == True:
assert (self.sync_mode == False)
self.param_bak_list = []
# add param_bak for each trainer
for p in self.param_grad_ep_mapping[endpoint]["params"]:
# each parameter should have w_bak for each trainer id
for i in range(self.trainer_num):
param_bak_name = "%s.trainer_%d_bak" % (p.name, i)
tmpvar = pserver_program.global_block().create_var(
# NOTE: this var name format is used in `request_get_handler`
name=param_bak_name,
type=p.type,
shape=p.shape,
dtype=p.dtype)
self.param_bak_list.append((p, tmpvar))
# step 3.4
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
......@@ -741,7 +766,7 @@ in a single call.")
grad_to_block_id, merged_var,
lr_ops)
# dedup grad to ids list
# dedup grad to ids list
grad_to_block_id = list(set(grad_to_block_id))
# append global ops
if global_ops:
......@@ -787,6 +812,8 @@ in a single call.")
if self.has_distributed_lookup_table:
attrs['checkpint_block_id'] = checkpoint_block_id
if self.config.enable_dc_asgd:
attrs['dc_asgd'] = True
if len(prefetch_var_name_to_block_id) > 0:
attrs[
......@@ -903,6 +930,15 @@ to transpile() call.")
inputs=new_inputs,
outputs=new_outputs,
attrs=op.all_attrs())
if self.config.enable_dc_asgd:
for p, p_bak in self.param_bak_list:
startup_param_var = s_prog.global_block().vars[p.name]
startup_tmpvar = s_prog.global_block().vars[p_bak.name]
# copy init random value to param_bak
s_prog.global_block().append_op(
type="assign",
inputs={"X": startup_param_var},
outputs={"Out": startup_tmpvar})
# add slice vars
s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
......@@ -1175,6 +1211,7 @@ to transpile() call.")
attrs={
"sync_mode": not self.sync_mode,
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[table_grad_name],
......@@ -1531,6 +1568,69 @@ to transpile() call.")
attrs={"scale": 1.0 / float(self.trainer_num)})
return merged_var
def _append_dc_asgd_ops(self, block, param_var, grad_var):
# NOTE: cannot use syntactic sugar here; these ops must be appended to the specific block
local_param_bak = block.create_var(
name="%s.local_bak" % param_var.name,
shape=param_var.shape,
type=param_var.type,
dtype=param_var.dtype,
persistable=False)
# trainer_id_var is block local
trainer_id_var = block.create_var(
name="@TRAINER_ID@",
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=core.VarDesc.VarType.INT64,
shape=[1],
persistable=False)
# ref_inputs = [x[1] for x in self.param_bak_list]
ref_inputs = []
for p, p_bak in self.param_bak_list:
if p.name == param_var.name:
print("#### ref inputs: ", param_var.name, p_bak.name)
ref_inputs.append(p_bak)
block.append_op(
type="ref_by_trainer_id",
inputs={"X": ref_inputs,
"TrainerId": trainer_id_var},
outputs={"Out": local_param_bak})
def __create_temp_var__():
return block.create_var(
name=unique_name.generate("tmp_dc_output"),
shape=param_var.shape,
type=param_var.type,
dtype=param_var.dtype,
persistable=False)
o1 = __create_temp_var__()
block.append_op(
type="elementwise_sub",
inputs={"X": param_var,
"Y": local_param_bak},
outputs={"Out": o1})
o2 = __create_temp_var__()
block.append_op(
type="elementwise_mul",
inputs={"X": o1,
"Y": grad_var},
outputs={"Out": o2})
o3 = __create_temp_var__()
block.append_op(
type="elementwise_mul",
inputs={"X": o2,
"Y": grad_var},
outputs={"Out": o3})
# TODO(typhoonzero): append scale
o4 = __create_temp_var__()
block.append_op(
type="elementwise_add",
inputs={"X": grad_var,
"Y": o3},
outputs={"Out": o4})
return o4
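# The four elementwise ops appended above compute the delay-compensated
# gradient that replaces the optimizer's "Grad" input in _append_pserver_ops
# below. A quick numpy check of the algebra (illustrative only; the lambda
# scale is the TODO noted above, so it is effectively 1.0 here):
#
#   import numpy as np
#   param, param_bak, grad = (np.random.rand(4) for _ in range(3))
#   o1 = param - param_bak   # elementwise_sub
#   o2 = o1 * grad           # elementwise_mul
#   o3 = o2 * grad           # elementwise_mul
#   o4 = grad + o3           # elementwise_add
#   assert np.allclose(o4, grad + grad * grad * (param - param_bak))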
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
grad_to_block_id, origin_program, merged_var):
program = optimize_block.program
......@@ -1546,9 +1646,16 @@ to transpile() call.")
break
return param_block
if self.config.enable_dc_asgd:
param_var = _get_param_block(opt_op)
dc = self._append_dc_asgd_ops(optimize_block, param_var, merged_var)
for key in opt_op.input_names:
if key == "Grad":
new_inputs[key] = merged_var
if self.config.enable_dc_asgd:
new_inputs[key] = dc
else:
new_inputs[key] = merged_var
elif key == "Param":
param_block = _get_param_block(opt_op)
if not param_block:
......