未验证 提交 dd7a48bd 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #10123 from jacquesqiao/split-optimize-op-into-signle-blocks

split optimization ops on pserver into independent blocks
...@@ -368,21 +368,19 @@ class DistributeTranspiler: ...@@ -368,21 +368,19 @@ class DistributeTranspiler:
else: else:
recv_inputs.append(single_trainer_var) recv_inputs.append(single_trainer_var)
# step3 # step 3
optimize_block = pserver_program.create_block(0)
# step 4
# Create a union-find data structure from optimize ops, # Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops # If two ops are connected, we could add these two ops
# into one set. # into one set.
ufind = self._create_ufind(self.optimize_ops) ufind = self._create_ufind(self.optimize_ops)
# step 4.2 # step 3.2
# Iterate through the ops and append optimize op which # Iterate through the ops and append optimize op which
# located on current pserver # located on current pserver
opt_op_on_pserver = [] opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op): if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
opt_op_on_pserver.append(op) opt_op_on_pserver.append(op)
# step 4.3 # step 3.3
# Iterate through the ops, and if an op and the optimize ops # Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then # which located on current pserver are in one set, then
# append it into the sub program. # append it into the sub program.
...@@ -415,29 +413,30 @@ class DistributeTranspiler: ...@@ -415,29 +413,30 @@ class DistributeTranspiler:
else: else:
self._append_pserver_non_opt_ops(block, op) self._append_pserver_non_opt_ops(block, op)
append_block = optimize_block
# append lr decay ops to the child block if exists # append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
if len(lr_ops) > 0: if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(append_block, op) self._append_pserver_non_opt_ops(lr_decay_block, op)
append_block = pserver_program.create_block(append_block.idx)
# append op to the current block # append op to the current block
per_opt_block = append_block pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver): for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself # optimizer is connected to itself
if ufind.is_connected(op, opt_op) and \ if ufind.is_connected(op, opt_op) and op not in global_ops:
op not in global_ops:
__append_optimize_op__(op, per_opt_block) __append_optimize_op__(op, per_opt_block)
if idx == len(opt_op_on_pserver) - 1 and global_ops:
per_opt_block = pserver_program.create_block(append_block.idx)
# append global ops # append global ops
for glb_op in global_ops: opt_state_block = None
__append_optimize_op__(glb_op, per_opt_block) if global_ops:
opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block)
# NOT USED: single block version: # NOT USED: single block version:
# #
...@@ -451,10 +450,10 @@ class DistributeTranspiler: ...@@ -451,10 +450,10 @@ class DistributeTranspiler:
prefetch_block = None prefetch_block = None
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint) pserver_index = self.pserver_endpoints.index(endpoint)
self._create_table_optimize_block(pserver_index, pserver_program, table_opt_block = self._create_table_optimize_block(
append_block) pserver_index, pserver_program, pre_block_idx)
prefetch_block = self._create_prefetch_block( prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, optimize_block) pserver_index, pserver_program, table_opt_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
...@@ -470,7 +469,7 @@ class DistributeTranspiler: ...@@ -470,7 +469,7 @@ class DistributeTranspiler:
inputs={'X': recv_inputs}, inputs={'X': recv_inputs},
outputs={}, outputs={},
attrs={ attrs={
"OptimizeBlock": optimize_block, "OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block "PrefetchBlock": prefetch_block
...@@ -663,7 +662,7 @@ class DistributeTranspiler: ...@@ -663,7 +662,7 @@ class DistributeTranspiler:
return prefetch_block return prefetch_block
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
append_block): pre_block_idx):
def _clone_var(block, var, persistable=True): def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable) assert isinstance(var, Variable)
return block.create_var( return block.create_var(
...@@ -700,7 +699,7 @@ class DistributeTranspiler: ...@@ -700,7 +699,7 @@ class DistributeTranspiler:
op for op in self.optimize_ops op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name if op.input("Param")[0] == self.table_name
][0] ][0]
table_opt_block = pserver_program.create_block(append_block.idx) table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now # only support sgd now
assert table_opt_op.type == "sgd" assert table_opt_op.type == "sgd"
...@@ -724,6 +723,8 @@ class DistributeTranspiler: ...@@ -724,6 +723,8 @@ class DistributeTranspiler:
outputs=outputs, outputs=outputs,
attrs=table_opt_op.attrs) attrs=table_opt_op.attrs)
return table_opt_block
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self, def _create_vars_from_blocklist(self,
program, program,
......
...@@ -1107,6 +1107,10 @@ class Program(object): ...@@ -1107,6 +1107,10 @@ class Program(object):
def random_seed(self): def random_seed(self):
return self._seed return self._seed
@property
def num_blocks(self):
return self.desc.num_blocks()
@random_seed.setter @random_seed.setter
def random_seed(self, seed): def random_seed(self, seed):
if not isinstance(seed, int): if not isinstance(seed, int):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册