Commit 02ea3491 authored by typhoonzero

enhance dist train performance

Parent bcfb82d3
......@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(var_name);
auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);
// varhandle
VarHandle var_h;
var_h.ep = ep;
......@@ -87,7 +84,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}
bool RPCClient::wait() {
bool RPCClient::Wait() {
bool ok = true;
while (true) {
......
......@@ -130,7 +130,7 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool wait();
bool Wait();
private:
bool Proceed();
......
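The wait() → Wait() rename goes with the client's calling pattern: the caller issues AsyncSendVariable/AsyncGetVariable for every variable without blocking, then calls Wait() once as a barrier for the whole batch. A minimal Python sketch of that issue-all-then-block-once pattern (send_variable and the endpoints below are placeholders, not the real C++ gRPC client):

```python
from concurrent.futures import ThreadPoolExecutor, wait


def send_variable(endpoint, name):
    # Stand-in for one blocking RPC; the real client issues a gRPC call here.
    print("sending %s to %s" % (name, endpoint))


def send_all_then_wait(endpoints, names):
    # Issue every request without blocking (like AsyncSendVariable) ...
    with ThreadPoolExecutor(max_workers=max(1, len(names))) as pool:
        futures = [pool.submit(send_variable, ep, name)
                   for ep, name in zip(endpoints, names)]
        # ... then block once until all of them finish (like RPCClient::Wait()).
        wait(futures)


send_all_then_wait(["127.0.0.1:6174", "127.0.0.1:6175"],
                   ["w1@GRAD", "w2@GRAD"])
```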
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
......@@ -77,35 +78,37 @@ class RecvOp : public framework::OperatorBase {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
char ret[256];
snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
grads_counter_[varname]++);
return std::string(ret);
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}
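The helper above now formats the per-trainer name with string::Sprintf instead of a fixed snprintf buffer; each gradient keeps its own counter, so repeated arrivals get the suffixes .trainer_0, .trainer_1, and so on. A small Python sketch of the same naming logic (the dict stands in for grads_counter_):

```python
grads_counter = {}


def grad_var_name_for_trainer(varname):
    # Mirror of GetGradVarNameForTrainer: suffix the gradient name with a
    # per-gradient counter so copies from different trainers do not collide.
    idx = grads_counter.setdefault(varname, 0)
    grads_counter[varname] = idx + 1
    return "%s.trainer_%d" % (varname, idx)


print(grad_var_name_for_trainer("w1@GRAD"))  # w1@GRAD.trainer_0
print(grad_var_name_for_trainer("w1@GRAD"))  # w1@GRAD.trainer_1
```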
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// FIXME(typhoonzero): no new scopes for every run.
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
rpc_service_->SetScope(&recv_scope);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
while (!exit_flag) {
// TODO(gongwb): simplify this loop.
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
for (size_t i = 0; i < param_count * trainer_count; ++i) {
// blocking get one var from client.
// Get from multiple trainers, we don't care about the order in which
the gradients arrive, just add suffix 0~n and merge the gradients.
for (size_t i = 0; i < param_count * fan_in; ++i) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
......@@ -114,44 +117,27 @@ class RecvOp : public framework::OperatorBase {
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
<< "\"";
LOG(ERROR) << "grad have no paired param:" << grad_var_name;
}
VLOG(3) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name;
auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
auto *ptr = recv_scope.Var(grad_var_name);
CreateTensorFromMessageType(ptr, v.second.type());
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr << " type is "
<< v.second.type();
// Assume grad_var_name must appear in global scope.
std::string grad_var_name_trainer;
if (fan_in > 1) {
grad_var_name_trainer = this->GetGradVarNameForTrainer(grad_var_name);
}
if (trainer_count > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
auto *var = recv_scope.FindVar(grad_var_name_trainer);
if (var == nullptr) {
LOG(ERROR) << "can not find server side var: "
<< grad_var_name_trainer;
PADDLE_THROW("can not find server side var");
}
auto *var = recv_scope.Var(grad_var_name);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
if (exit_flag) {
break;
}
rpc_service_->Reset();
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
// Run sub graph to get optimized tensor
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
......@@ -195,7 +181,7 @@ This operator will recv tensor from send_op
"GradList", "type list of string",
"grad->param name mapping to find which param to optimize.")
.SetDefault({});
AddAttr<int>("Trainers", "type int",
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
.SetDefault(1);
}
......
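Taken together, the RecvOp changes make one server-side round look like this: block for param_count * fan_in gradient messages, rename each arriving gradient with a per-trainer suffix when fan_in > 1, deserialize it into its pre-created variable, and only then run the OptimizeProgram once. A condensed Python model of that control flow (msg_queue, recv_scope, and run_optimize are placeholders for the gRPC queue, the framework scope, and the executor; recv_scope is just a dict here):

```python
TERMINATE = "TERMINATE@RECV"  # LISTEN_TERMINATE_MESSAGE


def recv_loop(msg_queue, param_count, fan_in, recv_scope, run_optimize):
    """Simplified model of RecvOp::Run's receive-then-optimize loop."""
    grads_counter = {}
    exit_flag = False
    while not exit_flag:
        # One round: wait for fan_in copies of every gradient.
        for _ in range(param_count * fan_in):
            name, value = msg_queue.get()  # blocking, like rpc_service_->Get()
            if name == TERMINATE:
                exit_flag = True
                break
            if fan_in > 1:
                idx = grads_counter.get(name, 0)
                grads_counter[name] = idx + 1
                # Wrap by fan_in (a simplification) so later rounds reuse the
                # same pre-created slots: w1@GRAD.trainer_0, .trainer_1, ...
                name = "%s.trainer_%d" % (name, idx % fan_in)
            recv_scope[name] = value  # stands in for DeserializeFromMessage
        if exit_flag:
            break
        run_optimize(recv_scope)  # run the OptimizeProgram once per round
```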
......@@ -41,14 +41,16 @@ class SendOp : public framework::OperatorBase {
// FIXME(gongwb): DeviceContext?
auto ctx = platform::CPUDeviceContext();
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
client_.Wait();
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
client_.wait();
client_.Wait();
}
private:
......
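send_op now waits twice per iteration: once after pushing all gradients and once after pulling all updated parameters, so every RPC inside a phase can overlap while the two phases stay ordered. A sketch of that two-phase shape (client and its async_send/async_get/wait methods are hypothetical stand-ins for the C++ RPCClient):

```python
def send_grads_and_fetch_params(client, epmap, grad_names, param_names):
    # Phase 1: push every gradient without blocking, then one barrier.
    for ep, grad in zip(epmap, grad_names):
        client.async_send(ep, grad)
    client.wait()
    # Phase 2: pull every updated parameter, then a second barrier, so the
    # fetched values reflect the optimization triggered by phase 1.
    for ep, param in zip(epmap, param_names):
        client.async_get(ep, param)
    client.wait()
```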
......@@ -452,6 +452,19 @@ class DistributeTranspiler:
pserver_program = Program()
for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
pserver_program.global_block().create_var(
name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
for trainer_id in xrange(self.trainers):
print("create variable for program: %s.trainer_%d" %
(v.name, trainer_id))
pserver_program.global_block().create_var(
name="%s.trainer_%d" % (v.name, trainer_id),
persistable=True,
dtype=v.dtype,
shape=v.shape)
# step6
optimize_sub_program = Program()
for idx, opt_op in enumerate(optimize_ops):
......@@ -481,7 +494,7 @@ class DistributeTranspiler:
p.name
for p in self.param_grad_ep_mapping[endpoint]["grads"]
],
"Trainers": self.trainers
"Fanin": self.trainers
})
pserver_program.sync_with_cpp()
return pserver_program
......
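For every gradient routed to a pserver endpoint, the transpiler now pre-creates one persistable variable per trainer (&lt;grad&gt;.trainer_0 … &lt;grad&gt;.trainer_N-1) in the pserver program's global block, so RecvOp never has to allocate them when a message arrives. A sketch of the names this produces, assuming two gradients and three trainers (the gradient names are illustrative):

```python
def pserver_grad_var_names(grad_names, trainers):
    # One shared slot per gradient plus one slot per trainer, matching the
    # "%s.trainer_%d" naming used by RecvOp.
    names = []
    for g in grad_names:
        names.append(g)
        names.extend("%s.trainer_%d" % (g, i) for i in range(trainers))
    return names


print(pserver_grad_var_names(["w1@GRAD", "b1@GRAD"], trainers=3))
# ['w1@GRAD', 'w1@GRAD.trainer_0', 'w1@GRAD.trainer_1', 'w1@GRAD.trainer_2',
#  'b1@GRAD', 'b1@GRAD.trainer_0', 'b1@GRAD.trainer_1', 'b1@GRAD.trainer_2']
```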
......@@ -39,26 +39,27 @@ train_reader = paddle.batch(
place = fluid.CPUPlace()
exe = fluid.Executor(place)
t = fluid.DistributeTranspiler()
# all parameter server endpoints list for splitting parameters
pserver_endpoints = os.getenv("PSERVERS")
# server endpoint for current node
current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
pserver_endpoints = os.getenv("PSERVERS") # all pserver endpoints
trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint = os.getenv("SERVER_ENDPOINT") # current pserver endpoint
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
t = fluid.DistributeTranspiler()
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
exe.run(fluid.default_startup_program())
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
# TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
......
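With the role and topology now read from environment variables, the same script is launched once per process; only the environment differs. An illustrative Python snippet of the environments a one-pserver, two-trainer job might use (addresses and counts are made up for the example):

```python
import os

# Hypothetical environments for a 1-pserver, 2-trainer job; every process
# runs the same training script and branches on TRAINING_ROLE.
pserver_env = {
    "PSERVERS": "127.0.0.1:6174",
    "TRAINERS": "2",
    "SERVER_ENDPOINT": "127.0.0.1:6174",
    "TRAINING_ROLE": "PSERVER",
}
trainer_env = {
    "PSERVERS": "127.0.0.1:6174",
    "TRAINERS": "2",
    "TRAINING_ROLE": "TRAINER",
}

# For example, a launcher would merge one of these into os.environ (or pass
# it to subprocess.Popen(..., env=...)) before starting the script.
os.environ.update(trainer_env)
print(os.getenv("TRAINING_ROLE"), os.getenv("TRAINERS"))
```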