未验证 提交 75d247b7 编写于 作者: W WangXi 提交者: GitHub

optimize grad add device (#33946)

上级 bd559a24
......@@ -474,11 +474,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
continue
if len(renamed_vars[var_name]) > 1:
if len(renamed_vars[var_name]) > _MAX_ADD_NUM_:
_accumulate_gradients_by_sum_op_(
var_name, renamed_vars, pending_sum_ops, idx, op_device)
_accumulate_gradients_by_sum_op_(var_name, renamed_vars,
pending_sum_ops, idx,
var_device[var_name])
else:
_accumulate_gradients_by_add_ops_(
var_name, renamed_vars, pending_sum_ops, idx, op_device)
_accumulate_gradients_by_add_ops_(var_name, renamed_vars,
pending_sum_ops, idx,
var_device[var_name])
for param_idx, param_name in enumerate(op_desc.output_names()):
arg_names = op_desc.output(param_name)
......@@ -529,7 +531,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name)
# record the latest device, for shared param
# record the latest device
var_device[var_name] = op_device
for var_name, inputs in six.iteritems(renamed_vars):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册