未验证 提交 75d247b7 编写于 作者: W WangXi 提交者: GitHub

optimize grad add device (#33946)

上级 bd559a24
...@@ -474,11 +474,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): ...@@ -474,11 +474,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
continue continue
if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > 1:
if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_:
_accumulate_gradients_by_sum_op_( _accumulate_gradients_by_sum_op_(var_name, renamed_vars,
var_name, renamed_vars, pending_sum_ops, idx, op_device) pending_sum_ops, idx,
var_device[var_name])
else: else:
_accumulate_gradients_by_add_ops_( _accumulate_gradients_by_add_ops_(var_name, renamed_vars,
var_name, renamed_vars, pending_sum_ops, idx, op_device) pending_sum_ops, idx,
var_device[var_name])
for param_idx, param_name in enumerate(op_desc.output_names()): for param_idx, param_name in enumerate(op_desc.output_names()):
arg_names = op_desc.output(param_name) arg_names = op_desc.output(param_name)
...@@ -529,7 +531,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx): ...@@ -529,7 +531,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
arg_names[arg_idx] = new_name arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names) op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
# record the latest device, for shared param # record the latest device
var_device[var_name] = op_device var_device[var_name] = op_device
for var_name, inputs in six.iteritems(renamed_vars): for var_name, inputs in six.iteritems(renamed_vars):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册