Commit 8e48c77b authored by: Y yi.wu

wip

Parent 3d69a82b
@@ -136,7 +136,7 @@ void ParallelExecutor::BCastParamsToGPUs(
   // the initializing bcast, all vars would be bcast from device(0),
   // otherwise
   // bcast from the specified device.
-  bool initializing = builder_.get() == nullptr ? false : true;
+  bool initializing = builder_.get() == nullptr ? true : false;
   for (auto &var : vars) {
     int var_dev_id =
@@ -153,6 +153,7 @@ void ParallelExecutor::BCastParamsToGPUs(
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
+    VLOG(3) << "run broadcast " << var << " " << var_dev_id;
     auto &main_tensor = main_var->Get<LoDTensor>();
     auto &dims = main_tensor.dims();
...
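The flipped ternary makes `initializing` true exactly when `builder_` is still null, which is what the comment above describes: the first broadcast sends every parameter from device(0), later broadcasts send from the device assigned to the variable. A minimal Python paraphrase of that rule (illustrative names only, not the actual C++ ParallelExecutor API):

```python
# Broadcast-source rule described by the comment above; a hedged sketch,
# not the ParallelExecutor implementation.
def bcast_source_device(builder, var_dev_id):
    initializing = builder is None   # the condition the ternary now encodes
    if initializing:
        return 0                     # initial bcast: all vars from device(0)
    return var_dev_id                # afterwards: the var's assigned device


# Quick check of both branches.
assert bcast_source_device(None, 3) == 0
assert bcast_source_device(object(), 3) == 3
```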
@@ -302,7 +302,6 @@ class DistributeTranspiler(object):
         """
         # remove optimize ops and add a send op to main_program
         delete_ops(self.origin_program.global_block(), self.optimize_ops)
-        # FIXME(typhoonzero): serialize once will fix error occurs when clone.
         self.origin_program.__str__()
         return self.origin_program

@@ -383,11 +382,12 @@ class DistributeTranspiler(object):
             if self._is_adam_connected_op(op):
                 global_ops.append(op)

-        def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
+        def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
+                                   lr_ops):
             if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
-            else:
+            elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)

         def __op_have_grad_input__(op):
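The new `lr_ops` parameter changes the dispatch rule inside `__append_optimize_op__`: optimizer ops still go through `_append_pserver_ops`, but ops that belong to the learning-rate schedule are no longer copied into the block as plain non-optimize ops. A small runnable sketch of just that dispatch decision (simplified stand-in names, not the transpiler's private helpers or real op sets):

```python
# Dispatch rule sketched from the change above; the op kinds and the lr set
# are illustrative examples, not the transpiler's real data structures.
def dispatch(op_type, lr_op_types):
    optimizer_types = {"sgd", "adam", "momentum"}  # stand-in for _is_optimizer_op
    if op_type in optimizer_types:
        return "append_pserver_ops"                # parameter-update path
    elif op_type not in lr_op_types:
        return "append_pserver_non_opt_ops"        # ordinary compute op
    return "skip"                                  # lr-schedule op: handled elsewhere


lr_op_types = {"increment", "fill_constant"}       # hypothetical lr-schedule ops
assert dispatch("adam", lr_op_types) == "append_pserver_ops"
assert dispatch("scale", lr_op_types) == "append_pserver_non_opt_ops"
assert dispatch("increment", lr_op_types) == "skip"
```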
@@ -452,7 +452,7 @@ class DistributeTranspiler(object):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
                     __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var)
+                                           merged_var, lr_ops)

         # append global ops
         if global_ops:
@@ -461,7 +461,7 @@ class DistributeTranspiler(object):
             optimize_blocks.append(opt_state_block)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
-                                       grad_to_block_id, None)
+                                       grad_to_block_id, None, lr_ops)

         # process distributed lookup_table
         prefetch_var_name_to_block_id = []
...
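Both call sites are updated to thread `lr_ops` through: the per-gradient optimize blocks and the global-ops block (where `merged_var` is `None`) now apply the same filter. A compact sketch of that call pattern, with simplified names in place of the transpiler's blocks and helpers:

```python
# Call-pattern sketch for the two patched call sites; everything here is a
# simplified stand-in, not the DistributeTranspiler API.
def build_optimize_blocks(per_grad_ops, global_ops, lr_ops, append_fn):
    blocks = []
    for ops in per_grad_ops:                   # one block per merged gradient
        block = []
        for op in ops:
            append_fn(op, block, lr_ops)       # merged_var omitted for brevity
        blocks.append(block)
    opt_state_block = []                       # shared block for global ops
    for op in global_ops:
        append_fn(op, opt_state_block, lr_ops)
    blocks.append(opt_state_block)
    return blocks


def append_fn(op, block, lr_ops):
    if op not in lr_ops:                       # lr-schedule ops stay out
        block.append(op)


print(build_optimize_blocks([["adam"]], ["scale", "increment"],
                            {"increment"}, append_fn))
# -> [['adam'], ['scale']]
```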