提交 8e48c77b 编写于 作者: Y yi.wu

wip

上级 3d69a82b
......@@ -136,7 +136,7 @@ void ParallelExecutor::BCastParamsToGPUs(
// the the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool initializing = builder_.get() == nullptr ? false : true;
bool initializing = builder_.get() == nullptr ? true : false;
for (auto &var : vars) {
int var_dev_id =
......@@ -153,6 +153,7 @@ void ParallelExecutor::BCastParamsToGPUs(
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
VLOG(3) << "run broadcast " << var << " " << var_dev_id;
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
......
......@@ -302,7 +302,6 @@ class DistributeTranspiler(object):
"""
# remove optimize ops and add a send op to main_program
delete_ops(self.origin_program.global_block(), self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.origin_program.__str__()
return self.origin_program
......@@ -383,11 +382,12 @@ class DistributeTranspiler(object):
if self._is_adam_connected_op(op):
global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
lr_ops):
if self._is_optimizer_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program, merged_var)
else:
elif op not in lr_ops:
self._append_pserver_non_opt_ops(block, op)
def __op_have_grad_input__(op):
......@@ -452,7 +452,7 @@ class DistributeTranspiler(object):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block, grad_to_block_id,
merged_var)
merged_var, lr_ops)
# append global ops
if global_ops:
......@@ -461,7 +461,7 @@ class DistributeTranspiler(object):
optimize_blocks.append(opt_state_block)
for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id, None)
grad_to_block_id, None, lr_ops)
# process distributed lookup_table
prefetch_var_name_to_block_id = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册