Unverified commit 279aa626, authored by Y Yancey and committed by GitHub

Move learning rate and related op to pserver (#8209)

* dist train support lr decay

* update by comment

* revert elementwise method creator

* delete comment
Parent commit: 72bcf72c
@@ -106,6 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
     // the gradients arrives, just add suffix 0~n and merge the gradient.
     rpc_service_->SetCond(0);
     size_t recv_var_cnt = 0;
+    size_t update_param_cnt = 0;
     int batch_barrier = 0;
     while (batch_barrier != fan_in) {
       const detail::MessageWithName &v = rpc_service_->Get();
@@ -126,13 +127,14 @@ class ListenAndServOp : public framework::OperatorBase {
       std::string param_var_name;
       if (it != grad_list.end()) {
         param_var_name = param_list[it - grad_list.begin()];
+        update_param_cnt++;
+        VLOG(3) << "received grad: " << grad_var_name
+                << " updating param: " << param_var_name;
       } else {
-        LOG(ERROR) << "grad has no paired param:" << grad_var_name;
+        VLOG(3) << "received variable: " << grad_var_name
+                << " no need to update param";
       }
-      VLOG(3) << "received grad: " << grad_var_name
-              << " updating param: " << param_var_name;
-      if (fan_in > 1) {
+      if (fan_in > 1 && !param_var_name.empty()) {
         grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
       }
       auto *var = recv_scope.FindVar(grad_var_name);
@@ -144,11 +146,10 @@ class ListenAndServOp : public framework::OperatorBase {
       }
     }
     VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
-    // TODO(Yancey1989): merge SelectedRows variables here
     if (exit_flag) {
       rpc_service_->ShutDown();
     }
-    VLOG(3) << "run optimize graph...";
     try {
       executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
                    false /*create_local_scope*/, false /*create_vars*/);
@@ -156,7 +157,7 @@ class ListenAndServOp : public framework::OperatorBase {
       LOG(ERROR) << "run sub program error " << e.what();
     }
     rpc_service_->SetCond(1);
-    rpc_service_->WaitClientGet(recv_var_cnt);
+    rpc_service_->WaitClientGet(update_param_cnt);
     grads_counter_.clear();
   }  // while(true)
 }
......
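The hunk above makes the pserver count two things separately: recv_var_cnt counts every variable received before the batch barrier, while the new update_param_cnt counts only gradients that have a paired parameter, so WaitClientGet now waits for exactly one client fetch per updated parameter. A rough Python-style sketch of that loop follows; the real implementation is the C++ ListenAndServOp above, and rpc_service, grad_to_param and run_optimize_block are stand-ins for the op's members:

# Sketch only: mirrors the counting logic of the diff above, not the real op.
def serve_one_round(rpc_service, grad_to_param, fan_in, run_optimize_block):
    recv_var_cnt = 0        # everything received before the batch barrier
    update_param_cnt = 0    # only gradients that have a paired parameter
    batch_barrier = 0
    while batch_barrier != fan_in:
        msg = rpc_service.get()
        if msg.is_batch_barrier:
            batch_barrier += 1
            continue
        recv_var_cnt += 1
        if msg.name in grad_to_param:
            update_param_cnt += 1
        # variables with no paired parameter are stored but trigger
        # no parameter update
    run_optimize_block()    # runs optimize ops plus the moved lr-decay ops
    # wait for one GET per updated parameter, not per received variable
    rpc_service.wait_client_get(update_param_cnt)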
@@ -33,6 +33,57 @@ class VarBlock:
         return "%s:%d:%d" % (self.varname, self.offset, self.size)
 
 
+class UnionFind(object):
+    """ Union-find data struct.
+
+    Union-find is a data struct that keeps track of a set of elements partitioned
+    into a number of disjoint (non-overlapping) subsets.
+
+    Reference:
+    https://en.wikipedia.org/wiki/Disjoint-set_data_structure
+
+    Args:
+      elements(list): The initialize element list.
+    """
+
+    def __init__(self, elementes=None):
+        self._parents = []  # index -> parent index
+        self._index = {}  # element -> index
+        self._curr_idx = 0
+        if not elementes:
+            elementes = []
+        for ele in elementes:
+            self._parents.append(self._curr_idx)
+            self._index.update({ele: self._curr_idx})
+            self._curr_idx += 1
+
+    def find(self, x):
+        # Find the root index of given element x,
+        # execute the path compress while findind the root index
+        if not x in self._index:
+            return -1
+        idx = self._index[x]
+        while idx != self._parents[idx]:
+            t = self._parents[idx]
+            self._parents[idx] = self._parents[t]
+            idx = t
+        return idx
+
+    def union(self, x, y):
+        # Union two given element
+        x_root = self.find(x)
+        y_root = self.find(y)
+
+        if x_root == y_root:
+            return
+        self._parents[x_root] = y_root
+
+    def is_connected(self, x, y):
+        # If two given elements have the same root index,
+        # then they are connected.
+        return self.find(x) == self.find(y)
+
+
 def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")
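A minimal usage sketch of the UnionFind helper added above; the string elements here are made up, while the transpiler itself stores operator objects:

# Assumes the UnionFind class from the diff above is in scope.
u = UnionFind(["sgd_0", "sgd_1", "lr_decay"])
u.union("sgd_0", "lr_decay")                  # merge the two elements into one set
print(u.is_connected("sgd_0", "lr_decay"))    # True
print(u.is_connected("sgd_0", "sgd_1"))       # False, never unioned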
@@ -178,6 +229,21 @@ class DistributeTranspiler:
                     outputs={"Out": [orig_param]},
                     attrs={"axis": 0})
 
+        self.lr_param_mapping = self._create_lr_param_mapping()
+
+    def _create_lr_param_mapping(self):
+        lr_mapping = dict()
+        for _, opt_op in enumerate(self.optimize_ops):
+            if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
+                    or not opt_op.inputs.has_key("Param"):
+                continue
+            lr = opt_op.inputs["LearningRate"].name
+            param = opt_op.inputs["Param"].name
+            if not lr_mapping.has_key(lr):
+                lr_mapping.update({lr: list()})
+            lr_mapping[lr].append(param)
+        return lr_mapping
+
     def _create_vars_from_blocklist(self, program, block_list):
         # Create respective variables using the block_list
         block_map = dict()
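For illustration, if two SGD-style optimize ops update fc_0.w_0 and fc_0.b_0 with the same learning-rate variable, the mapping built by _create_lr_param_mapping has roughly this shape; the variable names below are hypothetical:

# Hypothetical contents of self.lr_param_mapping: one learning-rate
# variable name -> the parameters whose optimize ops consume it.
lr_param_mapping = {"learning_rate_0": ["fc_0.w_0", "fc_0.b_0"]}
assert "fc_0.b_0" in lr_param_mapping["learning_rate_0"]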
@@ -300,52 +366,15 @@ class DistributeTranspiler:
                 pass
         return orig_shape
 
-    def _op_input_var(self, op, varname):
-        pass
-
-    def _is_op_on_pserver(self, endpoint, all_ops, idx):
-        """
-        Recursively check if the op need to run on current server.
-        Assume that ops are in the execution order.
-        """
-        param_names = [
-            p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
-        ]
-        op = all_ops[idx]
-        input_names = set(op.input_names)
-        # TODO(typhoonzero): using Param and Grad input name to identify
-        # that the operator is an optimization operator, need a better way.
-        if "Param" in input_names:
-            if op.input("Param")[0] in param_names:
-                return True
-            else:
-                for n in param_names:
-                    if same_or_split_var(n, op.input("Param")[0]) \
-                            and n != op.input("Param")[0]:
-                        return True
-            return False
-        else:
-            j = idx - 1
-            while j >= 0:
-                prev_op = all_ops[j]
-                # prev_output_names = [o.name for o in prev_op.outputs.values()]
-                # prev_input_names = [o.name for o in prev_op.inputs.values()]
-                # NOTE(typhoonzero): consider list input/output
-                prev_output_names = prev_op.desc.output_arg_names()
-                prev_input_names = prev_op.desc.input_arg_names()
-                found1 = False
-                found2 = False
-                for varname in op.desc.input_arg_names():
-                    if varname in prev_output_names:
-                        found1 = self._is_op_on_pserver(endpoint, all_ops, j)
-                # later ops may produce output for prev op's next batch use.
-                for varname in op.desc.output_arg_names():
-                    if varname in prev_input_names:
-                        found2 = self._is_op_on_pserver(endpoint, all_ops, j)
-                if found1 or found2:
-                    return True
-                j -= 1
-            return False
+    def _fetch_var_names(self, param_dict):
+        res = []
+        if not param_dict:
+            return res
+        for _, values in param_dict.iteritems():
+            if not isinstance(values, list):
+                values = [values]
+            res += [v.name for v in values]
+        return res
 
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
@@ -363,11 +392,7 @@ class DistributeTranspiler:
            # do not append this op if current endpoint
            # is not dealing with this grad block
            return
-        merged_var = program.global_block().create_var(
-            name=grad_block.name,
-            persistable=grad_block.persistable,
-            dtype=grad_block.dtype,
-            shape=grad_block.shape)
+        merged_var = program.global_block().vars[grad_block.name]
         # append merging ops if trainers > 1
         if self.trainers > 1:
             vars2merge = self._create_var_for_trainers(
@@ -398,13 +423,19 @@
                     shape=param_block.shape)
 
                 new_inputs[key] = tmpvar
+            elif key == "LearningRate":
+                # leraning rate variable has already be created by non-optimize op,
+                # don't create it once again.
+                new_inputs[key] = program.global_block().vars[opt_op.input(key)[
+                    0]]
 
         for key in opt_op.input_names:
-            if key in ["Param", "Grad"]:
+            new_shape = None
+            if key in ["Param", "Grad", "LearningRate"]:
                 continue
+            var = program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
-            var = program.global_block().vars[opt_op.input(key)[0]]
            new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                        var.shape, param_shape)
            tmpvar = program.global_block().create_var(
@@ -415,12 +446,11 @@
            new_inputs[key] = tmpvar
 
        # change output's ParamOut variable
-        outputs = self._get_output_map_from_op(program.global_block(), opt_op)
-        outputs["ParamOut"] = new_inputs["Param"]
+        opt_op.outputs["ParamOut"] = new_inputs["Param"]
        optimize_block.append_op(
            type=opt_op.type,
            inputs=new_inputs,
-            outputs=outputs,
+            outputs=opt_op.outputs,
            attrs=opt_op.attrs)
 
     def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
@@ -428,11 +458,10 @@
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(self.program.global_block().vars,
                                              opt_op)
-        for var in inputs.itervalues():
-            if type(var) == list:
-                varlist = var
-            else:
-                varlist = [var]
+        for varlist in inputs.itervalues():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
             for var in varlist:
                 if not program.global_block().vars.has_key(var.name):
                     program.global_block().create_var(
@@ -444,12 +473,70 @@
 
         outputs = self._get_output_map_from_op(self.program.global_block().vars,
                                                opt_op)
+        for varlist in outputs.itervalues():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                program.global_block().create_var(
+                    name=var.name,
+                    persistable=var.persistable,
+                    dtype=var.dtype,
+                    shape=var.shape)
         optimize_block.append_op(
             type=opt_op.type,
             inputs=inputs,
             outputs=outputs,
             attrs=opt_op.attrs)
 
+    def _is_op_connected(self, op1, op2):
+        # If one op's input is another op's output or
+        # one op's output is another op's input, we say
+        # the two operator is connected.
+        op1_input_names = self._fetch_var_names(op1.inputs)
+        op1_output_names = self._fetch_var_names(op1.outputs)
+
+        op2_input_names = self._fetch_var_names(op2.inputs)
+        op2_output_names = self._fetch_var_names(op2.outputs)
+        if set(op1_output_names) & set(op2_input_names) or \
+           set(op1_input_names) & set(op2_output_names):
+            return True
+        return False
+
+    def _create_ufind(self, optimize_ops):
+        # Create a unit find data struct by optimize ops
+        ufind = UnionFind(optimize_ops)
+        for i in xrange(len(optimize_ops)):
+            for j in xrange(i, len(optimize_ops)):
+                op1 = optimize_ops[i]
+                op2 = optimize_ops[j]
+                if self._is_op_connected(op1, op2):
+                    ufind.union(op1, op2)
+        return ufind
+
+    def _is_opt_op(self, op):
+        # NOTE: It's a HACK implement.
+        # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
+        if op.inputs and op.inputs.has_key("Param") \
+                and op.inputs.has_key("LearningRate"):
+            return True
+        return False
+
+    def _is_opt_op_on_pserver(self, endpoint, op):
+        param_names = [
+            p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
+        ]
+        if op.inputs["Param"].name in param_names:
+            return True
+        else:
+            for n in param_names:
+                param = op.inputs["Param"].name
+                if same_or_split_var(n, param) and n != op.inputs["Param"].name:
+                    return True
+            return False
+        return False
+
     def get_pserver_program(self, endpoint):
         """
         Get pserver side program using the endpoint
@@ -469,8 +556,6 @@
             pserver_program.global_block().create_var(
                 name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
             for trainer_id in xrange(self.trainers):
-                print("create variable for program: %s.trainer_%d" %
-                      (v.name, trainer_id))
                 pserver_program.global_block().create_var(
                     name="%s.trainer_%d" % (v.name, trainer_id),
                     persistable=True,
@@ -478,17 +563,30 @@
                     shape=v.shape)
         # step6
         optimize_block = pserver_program.create_block(0)
-        # Iterate through the ops and append ops as needed
-        for idx, opt_op in enumerate(self.optimize_ops):
-            is_op_on_pserver = self._is_op_on_pserver(endpoint,
-                                                      self.optimize_ops, idx)
-            if not is_op_on_pserver:
-                continue
-            if "Grad" in opt_op.desc.input_arg_names():
-                self._append_pserver_ops(optimize_block, opt_op, endpoint)
-            else:
-                self._append_pserver_non_opt_ops(optimize_block, opt_op)
+        # step 6.1
+        # Create a union-find data struct by optimize ops,
+        # If two ops are connected, we could add these two ops
+        # into one set.
+        ufind = self._create_ufind(self.optimize_ops)
+        # step 6.2
+        # Iterate through the ops and append optimize op which
+        # located on current pserver
+        opt_op_on_pserver = []
+        for _, op in enumerate(self.optimize_ops):
+            if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
+                opt_op_on_pserver.append(op)
+        # step 6.3
+        # Iterate through the ops, and if an op and the optimize ops
+        # which located on current pserver are in one set, then
+        # append it into the sub program.
+        for _, op in enumerate(self.optimize_ops):
+            for _, opt_op in enumerate(opt_op_on_pserver):
+                if ufind.is_connected(op, opt_op):
+                    if self._is_opt_op(op):
+                        self._append_pserver_ops(optimize_block, op, endpoint)
+                    else:
+                        self._append_pserver_non_opt_ops(optimize_block, op)
+                    break
         # Append the listen_and_serv op
         pserver_program.global_block().append_op(
             type="listen_and_serv",
......
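The step 6.1 to 6.3 logic above replaces the old recursive _is_op_on_pserver walk: the optimize ops are grouped with union-find, so any op that is connected, through shared input/output variables and possibly transitively, to an optimize op whose parameter lives on this pserver is appended to that pserver's optimize block; this is what lets a learning-rate decay op follow its SGD ops onto the pserver. A standalone sketch of that grouping, reusing the UnionFind class added earlier in this diff; the op and variable names below are made up:

# Sketch only: plain dicts stand in for Paddle operators.
ops = {
    # op name     (inputs,                                            outputs)
    "lr_decay":   ({"Step": ["step"]},                                {"Out": ["lr_0"]}),
    "sgd_fc_w":   ({"LearningRate": ["lr_0"], "Param": ["fc.w"],
                    "Grad": ["fc.w@GRAD"]},                           {"ParamOut": ["fc.w"]}),
    "sgd_fc_b":   ({"LearningRate": ["lr_0"], "Param": ["fc.b"],
                    "Grad": ["fc.b@GRAD"]},                           {"ParamOut": ["fc.b"]}),
}

def var_names(io_dict):
    return {name for names in io_dict.values() for name in names}

def connected(a, b):
    # Same rule as _is_op_connected: one op's outputs share a variable
    # with the other op's inputs.
    (ai, ao), (bi, bo) = ops[a], ops[b]
    return bool(var_names(ao) & var_names(bi) or var_names(ai) & var_names(bo))

u = UnionFind(list(ops))
for a in ops:
    for b in ops:
        if connected(a, b):
            u.union(a, b)

# lr_decay produces lr_0, which both SGD ops consume, so all three ops end up
# in one set and the decay op is appended to the same pserver optimize block.
print(u.is_connected("lr_decay", "sgd_fc_w"))  # True
print(u.is_connected("sgd_fc_w", "sgd_fc_b"))  # True (transitively, via lr_decay)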
@@ -117,6 +117,7 @@ def monkey_patch_variable():
         tmp_name = unique_tmp_name()
         out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
         self.block.append_op(
             type=op_type,
             inputs={'X': [self],
......
@@ -99,7 +99,7 @@ elif training_role == "TRAINER":
     exe.run(fluid.default_startup_program())
     for pass_id in range(PASS_NUM):
         for data in train_reader():
-            avg_cost_np = exe.run(fluid.default_main_program(),
+            avg_cost_np = exe.run(t.get_trainer_program(),
                                   feed=feeder.feed(data),
                                   fetch_list=[avg_cost])
             print("avg_cost_np", avg_cost_np)
......
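The last hunk switches the trainer from running fluid.default_main_program() to the program returned by t.get_trainer_program(), so the trainer executes the transpiled program while the learning-rate schedule runs on the pserver. A schematic outline of how such a script is typically wired up in the Fluid examples of that period; exe, feeder, train_reader, avg_cost, optimize_ops, params_grads and the endpoint/role variables come from the surrounding training script, and the exact transpile() signature may differ from what is shown here:

# Sketch only, assuming the v2-era import path and transpiler API.
import paddle.v2.fluid as fluid

t = fluid.DistributeTranspiler()
# optimize_ops / params_grads come from optimizer.minimize(avg_cost)
t.transpile(
    optimize_ops,
    params_grads,
    trainers=trainers,            # e.g. 2
    pservers=pserver_endpoints)   # e.g. "127.0.0.1:6174"

if training_role == "PSERVER":
    pserver_prog = t.get_pserver_program(current_endpoint)
    exe.run(t.get_startup_program(current_endpoint, pserver_prog))
    exe.run(pserver_prog)         # blocks in listen_and_serv
elif training_role == "TRAINER":
    exe.run(fluid.default_startup_program())
    for data in train_reader():
        exe.run(t.get_trainer_program(),
                feed=feeder.feed(data),
                fetch_list=[avg_cost])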