未验证 提交 0400eaed 编写于 作者: A Aurelius84 提交者: GitHub

[D2SCinn]Add test_cinn unittest and param_grad into skip_gc_vars (#49575)

* [D2SCinn]Add test_cinn unittest and param_grad into skip_gc_vars

* remove print
上级 e4c438f5
...@@ -149,7 +149,7 @@ void AppendSkipDeletionVars(const std::vector<std::string> &append_vars, ...@@ -149,7 +149,7 @@ void AppendSkipDeletionVars(const std::vector<std::string> &append_vars,
} }
std::set<std::string> ParseSafeEagerDeletionSkipVarsSet( std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
const ProgramDesc &backward_program) { const ProgramDesc &backward_program, bool skip_no_need_buffer) {
std::set<std::string> skip_eager_delete_vars; std::set<std::string> skip_eager_delete_vars;
auto backward_ops = backward_program.Block(0).AllOps(); auto backward_ops = backward_program.Block(0).AllOps();
auto &op_info_map = OpInfoMap::Instance(); auto &op_info_map = OpInfoMap::Instance();
...@@ -158,6 +158,7 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet( ...@@ -158,6 +158,7 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
std::unordered_set<std::string> no_need_buffer_ins; std::unordered_set<std::string> no_need_buffer_ins;
for (size_t i = 0; i < backward_ops.size(); ++i) { for (size_t i = 0; i < backward_ops.size(); ++i) {
framework::OpDesc *op = backward_ops[i]; framework::OpDesc *op = backward_ops[i];
VLOG(4) << "parse op type: " << op->Type();
if (op->Type() == "share_buffer") { if (op->Type() == "share_buffer") {
VLOG(1) << "skip share_buffer op"; VLOG(1) << "skip share_buffer op";
continue; continue;
...@@ -166,7 +167,9 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet( ...@@ -166,7 +167,9 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
auto &op_info = op_info_map.Get(op->Type()); auto &op_info = op_info_map.Get(op->Type());
auto &inferer = op_info.NoNeedBufferVarsInferer(); auto &inferer = op_info.NoNeedBufferVarsInferer();
no_need_buffer_ins.clear(); no_need_buffer_ins.clear();
if (inferer != nullptr) { // TODO(Aurelius84): Need remove skip_no_need_buffer after cinn fix this
// problem.
if (inferer != nullptr && !skip_no_need_buffer) {
no_need_buffer_ins = no_need_buffer_ins =
inferer(op->Inputs(), op->Outputs(), op->GetAttrMap()); inferer(op->Inputs(), op->Outputs(), op->GetAttrMap());
} }
...@@ -185,6 +188,7 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet( ...@@ -185,6 +188,7 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
} }
} }
for (const std::string &var_name : op_inputs) { for (const std::string &var_name : op_inputs) {
VLOG(4) << "parse op.input: " << var_name;
if (op_outputs.find(var_name) == op_outputs.end()) { if (op_outputs.find(var_name) == op_outputs.end()) {
VLOG(1) << "skip eager var: " << var_name; VLOG(1) << "skip eager var: " << var_name;
skip_eager_delete_vars.insert(var_name); skip_eager_delete_vars.insert(var_name);
......
...@@ -49,8 +49,10 @@ void ParseSafeEagerDeletionSkipVars( ...@@ -49,8 +49,10 @@ void ParseSafeEagerDeletionSkipVars(
void AppendSkipDeletionVars(const std::vector<std::string>& append_vars, void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
std::set<std::string>* all_vars); std::set<std::string>* all_vars);
// TODO(Aurelius84): Remove skip_no_need_buffer once CINN fixes this
// problem.
std::set<std::string> ParseSafeEagerDeletionSkipVarsSet( std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
const ProgramDesc& backward_program); const ProgramDesc& backward_program, bool skip_no_need_buffer = false);
} // namespace details } // namespace details
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class Net(paddle.nn.Layer):
    """Small layer used by the test: a 4->4 linear layer combined with a ones tensor, then ReLU."""

    def __init__(self):
        super().__init__()
        self.relu = paddle.nn.functional.relu
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        """Return ``relu(ones + fc(x) * ones)`` where ``ones`` is ``full_like(x, 1.0)``."""
        # Tensor of ones shaped like the input; it is marked as requiring a gradient.
        ones = paddle.full_like(x, 1.0)
        ones.stop_gradient = False
        scaled = self.fc(x) * ones
        return self.relu(ones + scaled)
def apply_to_static(net, use_cinn):
    """Wrap ``net`` with ``paddle.jit.to_static``, setting ``build_cinn_pass`` to ``use_cinn``."""
    strategy = paddle.static.BuildStrategy()
    strategy.build_cinn_pass = use_cinn
    static_net = paddle.jit.to_static(net, build_strategy=strategy)
    return static_net
class TestCINN(unittest.TestCase):
    """Compares per-step outputs of training ``Net`` with and without the CINN build pass."""

    def setUp(self):
        # A single random input, created once and reused by both training runs.
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_cinn):
        """Train a freshly seeded ``Net`` for 10 SGD steps.

        When ``use_cinn`` is true the network is first converted through
        ``apply_to_static``. Returns the list of forward outputs (as numpy
        arrays), one per step.
        """
        paddle.seed(2022)
        model = Net()
        optimizer = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=model.parameters()
        )
        if use_cinn:
            model = apply_to_static(model, use_cinn)

        outputs = []
        for _ in range(10):
            out = model(self.x)
            paddle.mean(out).backward()
            optimizer.step()
            optimizer.clear_grad()
            outputs.append(out.numpy())
        return outputs

    def test_cinn(self):
        """Outputs of the CINN run must equal the dynamic-graph run element-wise at every step."""
        dy_res = self.train(use_cinn=False)
        cinn_res = self.train(use_cinn=True)
        for idx, expected in enumerate(dy_res):
            np.testing.assert_array_equal(cinn_res[idx], expected)
# Run the unittest runner when this file is executed as a script.
if "__main__" == __name__:
    unittest.main()
...@@ -409,7 +409,7 @@ class PartialProgramLayer: ...@@ -409,7 +409,7 @@ class PartialProgramLayer:
for param in self._params: for param in self._params:
candidate = [ candidate = [
var_name var_name
for var_name in self.backward_program.block(0).vars.keys() for var_name in self._train_program.block(0).vars.keys()
if var_name.endswith(param.name + '@GRAD') if var_name.endswith(param.name + '@GRAD')
] ]
if candidate: if candidate:
...@@ -753,7 +753,11 @@ class PartialProgramLayer: ...@@ -753,7 +753,11 @@ class PartialProgramLayer:
self._outputs.var_ids self._outputs.var_ids
) )
backward_end_op_index = whole_program.desc.block(0).op_size() backward_end_op_index = whole_program.desc.block(0).op_size()
backward_skip_vars = self._parse_skip_gc_vars(whole_program) # For Backward process in CINN, all param@GRAD should be skipped for GC, because
# they will be shared in scope and used by optimizer.
backward_skip_vars = (
self._parse_skip_gc_vars(whole_program) + self._param_grad_names
)
backward_builded_program = add_build_strategy_for( backward_builded_program = add_build_strategy_for(
whole_program, whole_program,
backward_start_op_index, backward_start_op_index,
...@@ -843,7 +847,7 @@ class PartialProgramLayer: ...@@ -843,7 +847,7 @@ class PartialProgramLayer:
if backward_program: if backward_program:
for var_name in core.parse_safe_eager_deletion_skip_vars( for var_name in core.parse_safe_eager_deletion_skip_vars(
backward_program.desc backward_program.desc, True
): ):
skip_vars.append(var_name) skip_vars.append(var_name)
return skip_vars return skip_vars
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册