Unverified commit ac44d798, authored by lzydev, committed by GitHub

Fix `sharding_pass` and "nop" op to improve GC strategy (#56283)

* Improve GC for pipeline parallel

* Delete print

* Fix bugs in nop_op and sharding

---------
Co-authored-by: chenruibiao <chenruibiao@baidu.com>
Parent 67ab0371
@@ -45,14 +45,9 @@ establish the dependency between input and output tensors.
   }
 };
 
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(NopNoNeedBufferVarsInferer, "X", "Out");
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_WITHOUT_GRADIENT(nop,
-                             ops::NopOp,
-                             ops::NopOpMaker,
-                             ops::NopNoNeedBufferVarsInferer);
+REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker);
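
For context: `nop` performs no computation; per the op comment in the hunk header, it exists only to establish an execution-order dependency between its input and output tensors. Marking `X` and `Out` as no-need-buffer vars presumably let the garbage collector treat those buffers as reclaimable without waiting for `nop` to run, which defeats the dependency edge in pipeline parallelism, so the inferer is dropped from the registration. Below is a minimal sketch of wiring such a dependency from Python; `paddle.static.Program`, `paddle.static.data`, and `Block.append_op`/`Block.create_var` are real Paddle static-graph APIs, but feeding `x` in as `X` with a fresh dummy `Out` var is an assumption about how callers typically use `nop`, not code from this PR.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    y = paddle.add(x, x)

    block = main_prog.global_block()
    # Hypothetical dependency edge: `nop` computes nothing, but the executor
    # must treat `x` as in use until `nop` has run. With the removed
    # no-need-buffer inferer, GC could have freed `x`'s buffer earlier.
    dep = block.create_var(name="nop_dep", dtype="float32", shape=[1])
    block.append_op(type="nop", inputs={"X": [x]}, outputs={"Out": [dep]})
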
@@ -295,12 +295,12 @@ class ShardingPass(PassBase):
         self._insert_optimizer_broadcasts(main_block, startup_block)
 
     def _shard_amp_related_op_and_vars(self, main_block):
-        if self.stage < 2:
+        if self.stage < 1:
             return
 
         for idx, op in reversed(list(enumerate(main_block.ops))):
             # shard amp related param_grad cast
-            if _is_param_grad_fp32_cast_op(main_block, op):
+            if _is_param_grad_fp32_cast_op(main_block, op) and self.stage > 1:
                 output_name = op.output_arg_names[0]
                 param_name = output_name[: output_name.find("@")]
                 if not self._is_parameter_in_local_shard(param_name):
......
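
The net effect of the two edits: `_shard_amp_related_op_and_vars` now runs for stage 1 as well (previously it bailed out below stage 2), while the param-grad fp32-cast handling inside it stays gated behind `self.stage > 1`, i.e. it only applies once gradients themselves are sharded. The following is a standalone, runnable sketch of just that gating logic; `Op`, `is_param_grad_fp32_cast`, and the toy op list are hypothetical stand-ins for `main_block.ops` and `_is_param_grad_fp32_cast_op`, not the real pass.

from dataclasses import dataclass

@dataclass
class Op:
    type: str
    output: str = ""

def is_param_grad_fp32_cast(op: Op) -> bool:
    # Stand-in for _is_param_grad_fp32_cast_op(main_block, op).
    return op.type == "cast" and op.output.endswith("@GRAD")

def shard_amp_related_ops(ops: list[Op], stage: int) -> list[str]:
    """Return which grad casts the pass would touch at a given stage."""
    touched = []
    if stage < 1:  # patched: was `stage < 2`, so stage 1 no longer bails out
        return touched
    for op in reversed(ops):
        # The fp32 grad-cast branch stays stage-2+ only (the new `stage > 1`).
        if is_param_grad_fp32_cast(op) and stage > 1:
            touched.append(op.output)
    return touched

ops = [Op("cast", "linear_0.w_0@GRAD"), Op("scale", "loss")]
print(shard_amp_related_ops(ops, stage=1))  # [] - grad casts left alone
print(shard_amp_related_ops(ops, stage=2))  # ['linear_0.w_0@GRAD']
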