Unverified commit 0fe2001a, authored by Leo Chen and committed by GitHub

make variable 'gradient_merge_cond' local (#41262)

Parent acec26a1
@@ -1621,13 +1621,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             persistable=True,
             force_cpu=True)
-        cond_var = layers.create_global_var(
-            name="gradient_merge_cond",
-            shape=[1],
-            value=bool(0),
-            dtype='bool',
-            persistable=False,
-            force_cpu=True)
+        cond_var = main_block.create_var(
+            name="gradient_merge_cond", shape=[1], dtype='bool')
         with device_guard("cpu"):
             # step_var = (step_var + 1) % k_step
......
@@ -107,13 +107,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
         force_cpu=True)
     set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)
-    cond_var = layers.create_global_var(
-        name="gradient_merge_cond",
-        shape=[1],
-        value=bool(0),
-        dtype='bool',
-        persistable=False,
-        force_cpu=True)
+    cond_var = main_block.create_var(
+        name="gradient_merge_cond", shape=[1], dtype='bool')
     set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
     with device_guard("cpu"):
......
@@ -7098,13 +7098,8 @@ class GradientMergeOptimizer(object):
             persistable=True,
             force_cpu=True)
-        cond_var = layers.create_global_var(
-            name="gradient_merge_cond",
-            shape=[1],
-            value=bool(0),
-            dtype='bool',
-            persistable=False,
-            force_cpu=True)
+        cond_var = main_block.create_var(
+            name="gradient_merge_cond", shape=[1], dtype='bool')
         with device_guard("cpu"):
             # step_var = (step_var + 1) % k_step
......
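For context, here is a minimal sketch of what the swap changes, assuming an older static-graph Paddle release where `paddle.fluid.layers.create_global_var` is still available; the standalone script and the name `gradient_merge_cond_local` are illustrative, not part of the commit. As I understand it, `create_global_var` registers the variable at program scope and wires up a constant initializer for it via the startup program, whereas `Block.create_var` only declares a variable inside the given block, which is what keeps `gradient_merge_cond` local to the program being built.

```python
# Hedged sketch of the API difference behind this commit. Assumes an older
# Paddle build where paddle.fluid and layers.create_global_var still exist.
import paddle
import paddle.fluid as fluid
from paddle.fluid import layers

paddle.enable_static()

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    main_block = main_program.global_block()

    # Old style (removed by the commit): a program-scoped variable with a
    # constant initializer set up through the startup program.
    global_cond = layers.create_global_var(
        name="gradient_merge_cond",
        shape=[1],
        value=bool(0),
        dtype='bool',
        persistable=False,
        force_cpu=True)

    # New style (what the commit switches to): a plain variable declared
    # directly in the block, with no initializer side effects. The name
    # "gradient_merge_cond_local" is only for this demo, to avoid clashing
    # with the variable created above.
    local_cond = main_block.create_var(
        name="gradient_merge_cond_local", shape=[1], dtype='bool')

# Both variables are visible in the main block by name; only the
# create_global_var path also touched the startup program.
print(sorted(n for n in main_block.vars if "gradient_merge_cond" in n))
```

Under that reading, after this commit each program that runs the gradient-merge logic gets its own block-local `gradient_merge_cond` instead of sharing one globally registered, initializer-backed variable.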