Commit c32040c3 authored by typhoonzero

WIP: remove fan_in

Parent 3f616152
......@@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
server_thread_->join();
}
std::string GetGradVarNameForTrainer(const std::string &varname) const {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}
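For context, the removed GetGradVarNameForTrainer helper de-duplicated incoming gradients on the server by counting arrivals per name; after this change the transpiler renames variables on the trainer side instead, so the server already receives unique names. A rough Python model of the removed behavior (illustrative only; the variable name is hypothetical):

```python
# Model of the removed server-side helper: it numbered identically-named
# gradients as they arrived, using a per-name counter.
from collections import defaultdict

grads_counter = defaultdict(int)

def grad_var_name_for_trainer(varname):
    name = "%s.trainer_%d" % (varname, grads_counter[varname])
    grads_counter[varname] += 1
    return name

print(grad_var_name_for_trainer("w@GRAD"))  # w@GRAD.trainer_0
print(grad_var_name_for_trainer("w@GRAD"))  # w@GRAD.trainer_1
```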
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -91,9 +84,8 @@ class ListenAndServOp : public framework::OperatorBase {
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto fan_in = Attr<int>("Fanin");
auto ins = Inputs("X");
auto fan_in = ins.size();
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
......@@ -109,35 +101,21 @@ class ListenAndServOp : public framework::OperatorBase {
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
// receive a variable
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto it =
std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
auto *var = recv_scope.FindVar(recv_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
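With ParamList/GradList gone, the server no longer maps gradients to parameters here: each received variable is deserialized into the scope under exactly the name the trainer sent, and fan_in is taken from the number of "X" inputs rather than the Fanin attribute. The loop keeps receiving until it has seen fan_in BATCH_BARRIER messages. A toy Python model of that barrier logic (illustrative only; the message constant is a placeholder):

```python
# Toy model of the receive loop: collect variables until every sender
# (fan_in of them) has delivered a batch-barrier message.
BATCH_BARRIER_MESSAGE = "BATCH_BARRIER@RECV"  # placeholder constant

def recv_one_pass(messages, fan_in):
    received, barriers = [], 0
    it = iter(messages)
    while barriers != fan_in:
        name, payload = next(it)
        if name == BATCH_BARRIER_MESSAGE:
            barriers += 1
        else:
            received.append((name, payload))  # real code deserializes into the scope
    return received

# Two trainers, each sending one renamed gradient followed by a barrier.
msgs = [("w@GRAD.trainer_0", b"..."), (BATCH_BARRIER_MESSAGE, b""),
        ("w@GRAD.trainer_1", b"..."), (BATCH_BARRIER_MESSAGE, b"")]
print(len(recv_one_pass(msgs, fan_in=2)))  # 2
```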
......@@ -171,6 +149,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator
......@@ -184,17 +163,6 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<std::vector<std::string>>(
"ParamList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"GradList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({});
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
.SetDefault(1);
}
};
......
......@@ -82,6 +82,7 @@ class DistributeTranspiler:
def transpile(self,
optimize_ops,
params_grads,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
......@@ -98,10 +99,19 @@ class DistributeTranspiler:
:param optimize_ops: op list of optimization, should be the
return value of Optimizer.minimize
:type optimize_ops: list
:param params_grads: list of tuple(weight, gradient)
:type params_grads: list
:param trainer_id: a unique id for each trainer in the job.
:type trainer_id: int
:param program: program to optimize, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:return: a list of programs
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determine how to split variables
to different servers evenly.
:type split_method: function
"""
assert (callable(split_method))
if program is None:
......@@ -109,6 +119,11 @@ class DistributeTranspiler:
self.program = program
self.trainers = trainers
self.optimize_ops = optimize_ops
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
self.trainer_id = trainer_id
# steps to transpile:
# 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
# 2. modify trainer program add split_op to each Grad.
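Call sites must now pass a trainer id between the gradients and the keyword arguments. A minimal, hedged sketch of the new call (the TRAINER_ID environment variable is an assumption, not part of this change; optimize_ops and params_grads come from Optimizer.minimize() in a real program):

```python
import os

def build_transpiler(t, optimize_ops, params_grads):
    # Each trainer passes its own id so the transpiler can rename the
    # variables it sends with a ".trainer_%d" suffix.
    trainer_id = int(os.getenv("TRAINER_ID", "0"))  # assumed env var
    t.transpile(
        optimize_ops,
        params_grads,
        trainer_id,                  # new positional argument
        pservers="127.0.0.1:6174",
        trainers=1)
    return t
```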
......@@ -189,10 +204,17 @@ class DistributeTranspiler:
block_map[varname].append((long(offset), long(size)))
for varname, splited in block_map.iteritems():
orig_var = program.global_block().vars[varname]
var_mapping[varname] = []
if len(splited) == 1:
var_mapping[varname] = [orig_var]
# rename var to the trainer_id var
new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name)
var_mapping[varname] = \
[program.global_block().var(new_var_name)]
continue
var_mapping[varname] = []
orig_shape = orig_var.shape
orig_dim1_flatten = 1
if len(orig_shape) >= 2:
......@@ -205,11 +227,13 @@ class DistributeTranspiler:
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
var = program.global_block().create_var(
name="%s.block%d" % (varname, i),
name="%s.block%d.trainer_%d" %
(varname, i, self.trainer_id),
persistable=False,
dtype=orig_var.dtype,
shape=splited_shape)  # flattened split var
var_mapping[varname].append(var)
program.global_block().sync_with_cpp()
return var_mapping
def _clone_var(self, block, var):
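After this change, even unsplit variables are renamed with a ".trainer_%d" suffix, and split blocks get ".block%d.trainer_%d" names, so each trainer's sends are unique on the pserver. A small illustration of the naming scheme (the variable name is hypothetical):

```python
# Illustrative only: how trainer-side names are formed by the transpiler.
def split_var_names(varname, num_blocks, trainer_id):
    if num_blocks == 1:
        # unsplit: just tag the variable with the trainer id
        return ["%s.trainer_%d" % (varname, trainer_id)]
    # split: one name per block, each tagged with the trainer id
    return ["%s.block%d.trainer_%d" % (varname, i, trainer_id)
            for i in range(num_blocks)]

print(split_var_names("fc_0.w_0@GRAD", 1, 0))  # ['fc_0.w_0@GRAD.trainer_0']
print(split_var_names("fc_0.w_0@GRAD", 2, 1))
# ['fc_0.w_0@GRAD.block0.trainer_1', 'fc_0.w_0@GRAD.block1.trainer_1']
```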
......@@ -449,6 +473,7 @@ class DistributeTranspiler:
"""
# step5
pserver_program = Program()
recv_inputs = []
for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]:
......@@ -457,13 +482,19 @@ class DistributeTranspiler:
pserver_program.global_block().create_var(
name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
for trainer_id in xrange(self.trainers):
# change the client-side var name back to the original name by
# removing the ".trainer_%d" suffix
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
print("create variable for program: %s.trainer_%d" %
(v.name, trainer_id))
pserver_program.global_block().create_var(
name="%s.trainer_%d" % (v.name, trainer_id),
(orig_var_name, trainer_id))
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
persistable=True,
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(var)
# step6
optimize_sub_program = Program()
# Iterate through the ops and append ops as needed
......@@ -481,20 +512,20 @@ class DistributeTranspiler:
# Append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={},
inputs={'X': recv_inputs},
outputs={},
attrs={
"OptimizeBlock": optimize_sub_program.global_block(),
"endpoint": endpoint,
"ParamList": [
p.name
for p in self.param_grad_ep_mapping[endpoint]["params"]
],
"GradList": [
p.name
for p in self.param_grad_ep_mapping[endpoint]["grads"]
],
"Fanin": self.trainers
# "ParamList": [
# p.name
# for p in self.param_grad_ep_mapping[endpoint]["params"]
# ],
# "GradList": [
# p.name
# for p in self.param_grad_ep_mapping[endpoint]["grads"]
# ],
# "Fanin": self.trainers
})
pserver_program.sync_with_cpp()
return pserver_program
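On the pserver side, one receive variable is created per trainer for every gradient (stripping any client-side ".trainer_%d" suffix first to recover the original name), and the resulting list is wired into listen_and_serv as the duplicable "X" input, which is what the C++ side now reads as its fan-in via Inputs("X"). A hedged sketch of that name bookkeeping (names only; the real code creates framework variables):

```python
# Sketch: build the per-trainer receive variable names for one pserver.
def recv_var_names(grad_names, trainers):
    names = []
    for name in grad_names:
        # strip a client-side ".trainer_%d" suffix, if any
        suff_idx = name.find(".trainer_")
        orig = name[:suff_idx] if suff_idx >= 0 else name
        names.extend("%s.trainer_%d" % (orig, tid) for tid in range(trainers))
    return names

inputs = recv_var_names(["fc_0.w_0@GRAD.trainer_0"], trainers=2)
print(inputs)       # ['fc_0.w_0@GRAD.trainer_0', 'fc_0.w_0@GRAD.trainer_1']
print(len(inputs))  # the count listen_and_serv sees as its input size
```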
......
......@@ -282,6 +282,10 @@ class Variable(object):
def name(self):
return self.desc.name()
@name.setter
def name(self, new_name):
self.desc.set_name(new_name)
@property
def shape(self):
# convert to tuple, make it as same as numpy API.
......@@ -530,6 +534,12 @@ class Operator(object):
"""
return self.desc.input(name)
def rename_input(self, old_name, new_name):
self.desc.rename_input(old_name, new_name)
def rename_output(self, old_name, new_name):
self.desc.rename_output(old_name, new_name)
@property
def input_names(self):
"""
......@@ -539,6 +549,14 @@ class Operator(object):
"""
return self.desc.input_names()
@property
def input_arg_names(self):
return self.desc.input_arg_names()
@property
def output_arg_names(self):
return self.desc.output_arg_names()
def output(self, name):
"""
Get output arguments by the output parameter name
......@@ -716,6 +734,22 @@ class Block(object):
def has_var(self, name):
return name in self.vars
def rename_var(self, name, new_name):
"""
Rename variable in vars and ops' inputs and outputs
"""
if not self.has_var(name):
raise ValueError("var %s is not in current" % name)
orig_var = self.var(name)
del self.vars[name]
orig_var.name = new_name
self.vars[new_name] = orig_var
for op in self.ops:
if name in op.input_arg_names:
op.rename_input(name, new_name)
if name in op.output_arg_names:
op.rename_output(name, new_name)
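The framework additions (a name setter on Variable, input/output argument-name properties and rename helpers on Operator, and Block.rename_var) keep the block's vars map and every op's argument lists consistent when a variable is renamed. A hedged usage sketch, assuming the paddle.v2.fluid module path used elsewhere in this change and a hypothetical one-op program:

```python
import paddle.v2.fluid as fluid  # module path assumed from this codebase

# Hypothetical program: one variable feeding a scale op, then renamed in place.
prog = fluid.Program()
block = prog.global_block()
x = block.create_var(name="x", shape=[1], dtype="float32")
block.append_op(type="scale", inputs={"X": x}, outputs={"Out": x},
                attrs={"scale": 2.0})

block.rename_var("x", "x.trainer_0")
# Both the vars map and the op's argument lists now use the new name.
print(block.has_var("x.trainer_0"))   # True
print(block.ops[0].input_arg_names)   # lists 'x.trainer_0'
```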
def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block()
param = Parameter(global_block, *args, **kwargs)
......@@ -803,6 +837,7 @@ class Block(object):
for p in other.iter_parameters():
assert isinstance(p, Parameter)
v = self.vars.get(p.name, None)
print("var shape to copy", v)
if v is None:
raise ValueError("copy_param_info_from should be invoked with "
"same topology")
......
......@@ -58,14 +58,19 @@ trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint = os.getenv("SERVER_ENDPOINT") # current pserver endpoint
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
t = fluid.DistributeTranspiler()
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
optimize_ops,
params_grads,
0,
pservers=pserver_endpoints,
trainers=trainers)
if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
......