Commit c32040c3 authored by typhoonzero

WIP: remove fan_in

Parent 3f616152
@@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
     server_thread_->join();
   }
 
-  std::string GetGradVarNameForTrainer(const std::string &varname) const {
-    if (grads_counter_.find(varname) == grads_counter_.end()) {
-      grads_counter_[varname] = 0;
-    }
-    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
-  }
-
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
@@ -91,9 +84,8 @@ class ListenAndServOp : public framework::OperatorBase {
     // FIXME(Yancey1989): initialize rpc server with lazy mode.
     rpc_service_->SetScope(&recv_scope);
     rpc_service_->SetDevCtx(&dev_ctx);
-    auto param_list = Attr<std::vector<std::string>>("ParamList");
-    auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto fan_in = Attr<int>("Fanin");
+    auto ins = Inputs("X");
+    auto fan_in = ins.size();
     auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
     auto *program = block->Program();
@@ -109,35 +101,21 @@ class ListenAndServOp : public framework::OperatorBase {
       int batch_barrier = 0;
       while (batch_barrier != fan_in) {
         const detail::MessageWithName &v = rpc_service_->Get();
-        auto grad_var_name = v.first;
-        if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
+        auto recv_var_name = v.first;
+        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
           LOG(INFO) << "received terminate message and exit";
           exit_flag = true;
           break;
-        } else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
+        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
           VLOG(3) << "recv batch barrier message";
           batch_barrier++;
           continue;
         } else {
-          // receive a variable
+          VLOG(3) << "received grad: " << recv_var_name;
           recv_var_cnt++;
-          auto it =
-              std::find(grad_list.begin(), grad_list.end(), grad_var_name);
-          std::string param_var_name;
-          if (it != grad_list.end()) {
-            param_var_name = param_list[it - grad_list.begin()];
-          } else {
-            LOG(ERROR) << "grad has no paired param:" << grad_var_name;
-          }
-          VLOG(3) << "received grad: " << grad_var_name
-                  << " updating param: " << param_var_name;
-          if (fan_in > 1) {
-            grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
-          }
-          auto *var = recv_scope.FindVar(grad_var_name);
+          auto *var = recv_scope.FindVar(recv_var_name);
           if (var == nullptr) {
-            LOG(ERROR) << "Can not find server side var: " << grad_var_name;
+            LOG(ERROR) << "Can not find server side var: " << recv_var_name;
             PADDLE_THROW("Can not find server side var");
           }
           detail::DeserializeFromMessage(v.second, dev_ctx, var);
@@ -171,6 +149,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
     AddComment(R"DOC(
 ListenAndServ operator
@@ -184,17 +163,6 @@ from send_op and send back variables to recv_op.
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
     AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                     "BlockID to run on server side.");
-    AddAttr<std::vector<std::string>>(
-        "ParamList", "type list of string",
-        "grad->param name mapping to find which parameters to optimize.")
-        .SetDefault({});
-    AddAttr<std::vector<std::string>>(
-        "GradList", "type list of string",
-        "grad->param name mapping to find which parameters to optimize.")
-        .SetDefault({});
-    AddAttr<int>("Fanin", "type int",
-                 "Number of trainers in the current cluster job")
-        .SetDefault(1);
   }
 };
...
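The hunks above drop the ParamList/GradList/Fanin attributes: the server now derives fan_in from the size of its "X" input (auto fan_in = ins.size()). A rough Python sketch of what that input list looks like, for orientation only (the gradient names and trainer count are made-up assumptions, not taken from this commit):

# Illustrative sketch: the transpiler creates one ".trainer_%d" copy of each
# gradient assigned to this pserver and passes them all as the op's "X" input;
# the server-side fan_in is then simply the length of that list.
grads_on_this_pserver = ["w@GRAD", "b@GRAD"]   # hypothetical gradient names
trainers = 2
recv_inputs = ["%s.trainer_%d" % (g, t)
               for g in grads_on_this_pserver
               for t in range(trainers)]
fan_in = len(recv_inputs)   # mirrors `auto fan_in = ins.size();` above
print("%d recv inputs: %s" % (fan_in, recv_inputs))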
@@ -82,6 +82,7 @@ class DistributeTranspiler:
     def transpile(self,
                   optimize_ops,
                   params_grads,
+                  trainer_id,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
@@ -98,10 +99,19 @@ class DistributeTranspiler:
         :param optimize_ops: op list of optimization, should be the
             return value of Optimizer.minimize
         :type optimize_ops: list
+        :param params_grads: list of tuple(weight, gradient)
+        :type params_grads: list
+        :param trainer_id: one unique id for each trainer in a job.
+        :type trainer_id: int
         :param program: program to optimize, default is default_main_program
+        :type program: Program
         :param pservers: parameter server endpoints like "m1:6174,m2:6174"
         :type pservers: string
-        :return: return a list of programs
+        :param trainers: total number of workers/trainers in the job
+        :type trainers: int
+        :param split_method: A function to determine how to split variables
+            to different servers equally.
+        :type split_method: function
         """
         assert (callable(split_method))
         if program is None:
@@ -109,6 +119,11 @@ class DistributeTranspiler:
         self.program = program
         self.trainers = trainers
         self.optimize_ops = optimize_ops
+        # TODO(typhoonzero): currently trainer_id is fetched from cluster system
+        # like Kubernetes, we should port this to use etcd later when developing
+        # fluid distributed training with fault-tolerance.
+        self.trainer_id = trainer_id
         # steps to transpile:
         # 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
         # 2. modify trainer program add split_op to each Grad.
@@ -189,10 +204,17 @@ class DistributeTranspiler:
                 block_map[varname].append((long(offset), long(size)))
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().vars[varname]
+            var_mapping[varname] = []
             if len(splited) == 1:
-                var_mapping[varname] = [orig_var]
+                # rename var to the trainer_id var
+                new_var_name = "%s.trainer_%d" % \
+                    (orig_var.name, self.trainer_id)
+                program.global_block().rename_var(varname, new_var_name)
+                var_mapping[varname] = \
+                    [program.global_block().var(new_var_name)]
                 continue
-            var_mapping[varname] = []
             orig_shape = orig_var.shape
             orig_dim1_flatten = 1
             if len(orig_shape) >= 2:
@@ -205,11 +227,13 @@ class DistributeTranspiler:
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
                 var = program.global_block().create_var(
-                    name="%s.block%d" % (varname, i),
+                    name="%s.block%d.trainer_%d" %
+                    (varname, i, self.trainer_id),
                     psersistable=False,
                     dtype=orig_var.dtype,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
+        program.global_block().sync_with_cpp()
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -449,6 +473,7 @@ class DistributeTranspiler:
         """
         # step5
         pserver_program = Program()
+        recv_inputs = []
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
@@ -457,13 +482,19 @@ class DistributeTranspiler:
             pserver_program.global_block().create_var(
                 name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
             for trainer_id in xrange(self.trainers):
+                # change client side var name to origin name by
+                # removing ".trainer_%d" suffix
+                suff_idx = v.name.find(".trainer_")
+                if suff_idx >= 0:
+                    orig_var_name = v.name[:suff_idx]
                 print("create variable for program: %s.trainer_%d" %
-                      (v.name, trainer_id))
-                pserver_program.global_block().create_var(
-                    name="%s.trainer_%d" % (v.name, trainer_id),
+                      (orig_var_name, trainer_id))
+                var = pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (orig_var_name, trainer_id),
                     persistable=True,
                     dtype=v.dtype,
                     shape=v.shape)
+                recv_inputs.append(var)
         # step6
         optimize_sub_program = Program()
         # Iterate through the ops and append ops as needed
@@ -481,20 +512,20 @@ class DistributeTranspiler:
         # Append the listen_and_serv op
         pserver_program.global_block().append_op(
             type="listen_and_serv",
-            inputs={},
+            inputs={'X': recv_inputs},
             outputs={},
             attrs={
                 "OptimizeBlock": optimize_sub_program.global_block(),
                 "endpoint": endpoint,
-                "ParamList": [
-                    p.name
-                    for p in self.param_grad_ep_mapping[endpoint]["params"]
-                ],
-                "GradList": [
-                    p.name
-                    for p in self.param_grad_ep_mapping[endpoint]["grads"]
-                ],
-                "Fanin": self.trainers
+                # "ParamList": [
+                #     p.name
+                #     for p in self.param_grad_ep_mapping[endpoint]["params"]
+                # ],
+                # "GradList": [
+                #     p.name
+                #     for p in self.param_grad_ep_mapping[endpoint]["grads"]
+                # ],
+                # "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
...
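For reference, a minimal sketch of the updated transpile call (not part of this commit; it assumes optimize_ops and params_grads were already produced by Optimizer.minimize, and the endpoints are placeholders). trainer_id is now the third positional argument:

import paddle.fluid as fluid  # older releases ship this module as paddle.v2.fluid

# `optimize_ops` and `params_grads` come from Optimizer.minimize,
# as described in the transpile docstring above.
t = fluid.DistributeTranspiler()
t.transpile(
    optimize_ops,
    params_grads,
    0,                                          # trainer_id: unique per trainer
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2)
pserver_prog = t.get_pserver_program("127.0.0.1:6174")  # on that pserver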
@@ -282,6 +282,10 @@ class Variable(object):
     def name(self):
         return self.desc.name()
 
+    @name.setter
+    def name(self, new_name):
+        self.desc.set_name(new_name)
+
     @property
     def shape(self):
         # convert to tuple, make it as same as numpy API.
@@ -530,6 +534,12 @@ class Operator(object):
         """
         return self.desc.input(name)
 
+    def rename_input(self, old_name, new_name):
+        self.desc.rename_input(old_name, new_name)
+
+    def rename_output(self, old_name, new_name):
+        self.desc.rename_output(old_name, new_name)
+
     @property
     def input_names(self):
         """
@@ -539,6 +549,14 @@ class Operator(object):
         """
         return self.desc.input_names()
 
+    @property
+    def input_arg_names(self):
+        return self.desc.input_arg_names()
+
+    @property
+    def output_arg_names(self):
+        return self.desc.output_arg_names()
+
     def output(self, name):
         """
         Get output arguments by the output parameter name
@@ -716,6 +734,22 @@ class Block(object):
     def has_var(self, name):
         return name in self.vars
 
+    def rename_var(self, name, new_name):
+        """
+        Rename variable in vars and ops' inputs and outputs
+        """
+        if not self.has_var(name):
+            raise ValueError("var %s is not in current" % name)
+        orig_var = self.var(name)
+        del self.vars[name]
+        orig_var.name = new_name
+        self.vars[new_name] = orig_var
+        for op in self.ops:
+            if name in op.input_arg_names:
+                op.rename_input(name, new_name)
+            if name in op.output_arg_names:
+                op.rename_output(name, new_name)
+
     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
         param = Parameter(global_block, *args, **kwargs)
@@ -803,6 +837,7 @@ class Block(object):
         for p in other.iter_parameters():
             assert isinstance(p, Parameter)
             v = self.vars.get(p.name, None)
+            print("var shape to copy", v)
             if v is None:
                 raise ValueError("copy_param_info_from should be invoked with "
                                  "same topology")
...
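The framework.py additions above (the Variable.name setter, Operator.rename_input/rename_output, the input_arg_names/output_arg_names properties, and Block.rename_var) exist so the transpiler can rename gradient variables to their per-trainer names on the client side, as done in _create_vars_from_blocklist. A minimal usage sketch (the program and variable name below are hypothetical, not taken from this commit):

# Assumes `prog` is a fluid Program whose global block already contains a
# variable named "fc_0.w_0@GRAD" (hypothetical name).
block = prog.global_block()
block.rename_var("fc_0.w_0@GRAD", "fc_0.w_0@GRAD.trainer_0")
# The block's var table and every op input/output that referenced the old
# name are updated in one call.
renamed = block.var("fc_0.w_0@GRAD.trainer_0")
assert renamed.name == "fc_0.w_0@GRAD.trainer_0"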
@@ -58,14 +58,19 @@
 trainers = int(os.getenv("TRAINERS"))  # total trainer count
 current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
+if not current_endpoint:
+    print("need env SERVER_ENDPOINT")
+    exit(1)
 t = fluid.DistributeTranspiler()
 t.transpile(
-    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
+    optimize_ops,
+    params_grads,
+    0,
+    pservers=pserver_endpoints,
+    trainers=trainers)
 
 if training_role == "PSERVER":
-    if not current_endpoint:
-        print("need env SERVER_ENDPOINT")
-        exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint)
     pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
     exe.run(pserver_startup)
...