提交 cc6ef41d 编写于 作者: Q Qiao Longfei

update dist transpiler

上级 47280ef8
......@@ -236,6 +236,22 @@ class DistributeTranspiler(object):
else:
raise ValueError("must set trainer_id > 0")
def _get_all_sparse_update_op(self, main_program):
sparse_update_ops = []
sparse_update_op_types = ["lookup_table"]
for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr(
'is_sparse') is True and not op.attr('is_distributed'):
sparse_update_ops.append(op)
return sparse_update_ops
def _update_sparse_update_op(self, param_varname, height_sections,
endpint_map):
for op in self.sparse_update_ops:
if param_varname in op.input_arg_names:
op._set_attr('epmap', endpint_map)
op._set_attr('height_sections', height_sections)
def transpile(self,
trainer_id,
program=None,
......@@ -299,6 +315,11 @@ class DistributeTranspiler(object):
self.param_name_to_grad_name[param_var.name] = grad_var.name
self.grad_name_to_param_name[grad_var.name] = param_var.name
# get all sparse update ops
self.sparse_update_ops = self._get_all_sparse_update_op(
self.origin_program)
self.sparse_param_to_height_sections = dict()
# add distributed attrs to program
self.origin_program._is_distributed = True
self.origin_program._endpoints = self.pserver_endpoints
......@@ -425,6 +446,12 @@ class DistributeTranspiler(object):
if len(splited_trainer_grad) == 1:
recv_op_role_var_name = splited_trainer_grad[0].name
if param_varname in self.sparse_param_to_height_sections:
height_sections = self.sparse_param_to_height_sections[
param_varname]
self._update_sparse_update_op(param_varname, height_sections,
eps)
else:
program.global_block().append_op(
type="recv",
inputs={"X": [recv_dep_in]},
......@@ -454,6 +481,9 @@ class DistributeTranspiler(object):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[param_varname]
print("sparse_param_to_height_sections: " + str(
self.sparse_param_to_height_sections))
if param_varname not in self.sparse_param_to_height_sections:
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
......@@ -1237,9 +1267,8 @@ to transpile() call.")
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
if 'Param' in op.input_names and op.input("Param")[0] ==
self.table_name
op for op in self.optimize_ops if 'Param' in op.input_names and
op.input("Param")[0] == self.table_name
][0]
origin_param_var = self.origin_program.global_block().vars[
......@@ -1418,6 +1447,10 @@ to transpile() call.")
height_sections = []
for v in splited_vars:
height_sections.append(v.shape[0])
sparse_param_name = self.grad_name_to_param_name[orig_var.name]
if sparse_param_name != self.table_name:
self.sparse_param_to_height_sections[
sparse_param_name] = height_sections
program.global_block()._insert_op(
index=index + 1,
type="split_selected_rows",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册