未验证 提交 ac7f09a9 编写于 作者: S sneaxiy 提交者: GitHub

Make FLAGS_force_align_vpp_grad_sum_order default to false (#54937)

* make FLAGS_force_align_vpp_grad_sum_order default to false

* polish code
上级 7c764060
......@@ -49,7 +49,7 @@ class HybridParallelClipGrad:
self.not_sharding_stage1 = True
self._vpp_chunk_num = None
self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
- os.getenv('FLAGS_force_align_vpp_grad_sum_order', '1')
+ os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0')
)
def _get_vpp_chunk_num(self, params_grads):
......@@ -168,9 +168,10 @@ class HybridParallelClipGrad:
@no_grad()
def _dygraph_clip(self, params_grads):
- chunk_num = self._get_vpp_chunk_num(params_grads)
- if chunk_num > 0 and self._force_align_vpp_grad_sum_order:
- return self._vpp_dygraph_clip(params_grads, chunk_num)
+ if self._force_align_vpp_grad_sum_order:
+ chunk_num = self._get_vpp_chunk_num(params_grads)
+ if chunk_num > 0:
+ return self._vpp_dygraph_clip(params_grads, chunk_num)
sum_square_dist_fp16 = []
sum_square_dist_bf16 = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册