Unverified · Commit ac44d798 authored by lzydev, committed by GitHub

Fix `sharding_pass` and "nop" op to improve GC strategy (#56283)

* Improve GC for pipeline parallel

* Delete print

* fix bug of nop_op and sharding

---------
Co-authored-by: chenruibiao <chenruibiao@baidu.com>
Parent 67ab0371
@@ -45,14 +45,9 @@ establish the dependency between input and output tensors.
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(NopNoNeedBufferVarsInferer, "X", "Out");
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(nop,
-                             ops::NopOp,
-                             ops::NopOpMaker,
-                             ops::NopNoNeedBufferVarsInferer);
+REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker);
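For context, here is a minimal sketch of how a `nop` op can be appended to a static-graph block to keep a dependency on a tensor alive. It is illustrative only, not code from this PR; the variable names and the surrounding program are assumptions. The point of the change above is that, with the no-need-buffer inferer removed, the executor treats the `X` buffer of `nop` as required, so the GC no longer frees the tensor the op is meant to pin.

```python
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.nn.functional.relu(x)

    block = main_prog.global_block()
    # Illustrative: pass `y` as both "X" and "Out" purely to create an explicit
    # dependency edge on it. After this PR, `nop` needs its input buffer, so the
    # garbage collector keeps `y` alive until the nop has executed.
    block.append_op(type="nop", inputs={"X": [y]}, outputs={"Out": [y]})

print(main_prog)
```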
@@ -295,12 +295,12 @@ class ShardingPass(PassBase):
         self._insert_optimizer_broadcasts(main_block, startup_block)

     def _shard_amp_related_op_and_vars(self, main_block):
-        if self.stage < 2:
+        if self.stage < 1:
             return
         for idx, op in reversed(list(enumerate(main_block.ops))):
             # shard amp related param_grad cast
-            if _is_param_grad_fp32_cast_op(main_block, op):
+            if _is_param_grad_fp32_cast_op(main_block, op) and self.stage > 1:
                 output_name = op.output_arg_names[0]
                 param_name = output_name[: output_name.find("@")]
                 if not self._is_parameter_in_local_shard(param_name):
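The sharding change has two parts: the early-return threshold drops from `stage < 2` to `stage < 1`, so stage-1 sharding now walks the AMP-related ops as well, while the param_grad fp32-cast branch gains an explicit `self.stage > 1` guard, keeping cast sharding a stage-2+ behavior. A minimal stand-in sketch of that gating (plain Python, not the real ShardingPass; `is_fp32_cast` is a hypothetical placeholder for `_is_param_grad_fp32_cast_op`):

```python
def shard_amp_related_ops(ops, stage, is_fp32_cast):
    sharded = []
    if stage < 1:  # previously `if stage < 2: return`, which skipped stage 1 entirely
        return sharded
    for idx, op in reversed(list(enumerate(ops))):
        # shard amp related param_grad cast
        if is_fp32_cast(op) and stage > 1:  # cast sharding still requires stage >= 2
            sharded.append(op)
    return sharded

ops = ["cast_fp32", "matmul", "cast_fp32"]
for stage in (1, 2, 3):
    print(stage, shard_amp_related_ops(ops, stage, lambda op: op == "cast_fp32"))
# stage 1 now enters the loop (it used to return immediately) but shards no casts;
# stages 2 and 3 shard the fp32 casts exactly as before.
```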