未验证 提交 c490f1b3 编写于 作者: 武毅 提交者: GitHub

Merge pull request #8049 from typhoonzero/no_counter_on_pserver

Enhancement/transpiler rename grad vars to add trainer id, so RPC call can be retried.
......@@ -42,6 +42,25 @@ bool BlockDesc::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end();
}
VarDesc *BlockDesc::RenameVar(const std::string &old_name,
const std::string &new_name) {
if (!this->HasVar(old_name)) {
return nullptr;
}
need_update_ = true;
auto *var = this->Var(old_name);
VarDesc *new_var = new VarDesc(*(var->Proto()));
new_var->SetName(new_name);
vars_[new_name].reset(new_var);
// rename inputs and outputs
for (const auto &op : ops_) {
auto *it = op.get();
it->Rename(old_name, new_name);
}
vars_.erase(old_name);
return new_var;
}
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr;
......
......@@ -55,6 +55,8 @@ class BlockDesc {
bool HasVar(const std::string &var_name) const;
VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
VarDesc *FindVarRecursive(const std::string &name_bytes) const;
VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
......
......@@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
server_thread_->join();
}
std::string GetGradVarNameForTrainer(const std::string &varname) const {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -91,8 +84,7 @@ class ListenAndServOp : public framework::OperatorBase {
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
......@@ -109,40 +101,24 @@ class ListenAndServOp : public framework::OperatorBase {
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
size_t update_param_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
// receive a variable
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto it =
std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
update_param_cnt++;
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
} else {
VLOG(3) << "received variable: " << grad_var_name
<< " no need to update param";
}
if (fan_in > 1 && !param_var_name.empty()) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
auto *var = recv_scope.FindVar(recv_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
......@@ -151,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
if (exit_flag) {
rpc_service_->ShutDown();
}
VLOG(3) << "run optimize graph...";
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear();
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size());
sparse_vars.clear();
} // while(true)
}
......@@ -181,13 +154,13 @@ class ListenAndServOp : public framework::OperatorBase {
protected:
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
std::shared_ptr<std::thread> server_thread_;
mutable std::unordered_map<std::string, int> grads_counter_;
};
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator
......@@ -201,16 +174,7 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<std::vector<std::string>>(
"ParamList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"GradList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({});
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
};
......
......@@ -170,6 +170,14 @@ void BindBlockDesc(py::module &m) {
[](BlockDesc &self, py::bytes byte_name) {
std::string name = byte_name;
return self.HasVar(name);
},
py::return_value_policy::reference)
.def("rename_var",
[](BlockDesc &self, const py::bytes &byte_name,
const py::bytes &byte_name_new) {
std::string name = byte_name;
std::string new_name = byte_name_new;
self.RenameVar(name, new_name);
})
.def("has_var_recursive",
[](BlockDesc &self, py::bytes byte_name) {
......@@ -198,7 +206,7 @@ void BindVarDsec(py::module &m) {
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
.def("name",
[](const VarDesc &self) {
[](VarDesc &self) {
py::bytes name = self.Name();
return name;
},
......
......@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename):
print "file md5", md5file(filename), md5sum
if retry < retry_limit:
retry += 1
else:
......
......@@ -286,6 +286,10 @@ class Variable(object):
def name(self):
return self.desc.name()
@name.setter
def name(self, new_name):
self.desc.set_name(new_name)
@property
def shape(self):
# convert to tuple, make it as same as numpy API.
......@@ -531,6 +535,12 @@ class Operator(object):
"""
return self.desc.input(name)
def rename_input(self, old_name, new_name):
self.desc.rename_input(old_name, new_name)
def rename_output(self, old_name, new_name):
self.desc.rename_output(old_name, new_name)
@property
def input_names(self):
"""
......@@ -540,6 +550,14 @@ class Operator(object):
"""
return self.desc.input_names()
@property
def input_arg_names(self):
return self.desc.input_arg_names()
@property
def output_arg_names(self):
return self.desc.output_arg_names()
def output(self, name):
"""
Get output arguments by the output parameter name
......@@ -717,6 +735,60 @@ class Block(object):
def has_var(self, name):
return name in self.vars
def rename_var(self, name, new_name):
"""
Rename variable in vars and ops' inputs and outputs
"""
if not self.has_var(name):
raise ValueError("var %s is not in current" % name)
v = self.var(name)
stop_gradient = None
trainable = None
optimize_attr = None
regularizer = None
gradient_clip_attr = None
error_clip = None
if type(v) == Parameter:
stop_gradient = v.stop_gradient
trainable = v.trainable
optimize_attr = v.optimize_attr
regularizer = v.regularizer
gradient_clip_attr = v.gradient_clip_attr
error_clip = v.error_clip
elif type(v) == Variable:
error_clip = v.error_clip
stop_gradient = v.stop_gradient
else:
raise ValueError("unsupported var type: %s", type(v))
self.desc.rename_var(name, new_name)
d = self.desc.find_var(new_name)
var = None
if type(v) == Parameter:
var = Parameter(
self,
d.shape(),
d.dtype(),
name=new_name,
stop_gradient=stop_gradient,
trainable=trainable,
optimize_attr=optimize_attr,
regularizer=regularizer,
gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
elif type(v) == Variable:
var = Variable(
self,
name=new_name,
error_clip=error_clip,
stop_gradient=stop_gradient)
# rename the python side, sync_with_cpp will only add
# new vars/ops to python side.
self.vars[new_name] = var
del self.vars[name]
self.sync_with_cpp()
def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block()
param = Parameter(global_block, *args, **kwargs)
......
......@@ -58,14 +58,19 @@ trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint = os.getenv("SERVER_ENDPOINT") # current pserver endpoint
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
t = fluid.DistributeTranspiler()
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
optimize_ops,
params_grads,
0,
pservers=pserver_endpoints,
trainers=trainers)
if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册