diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc
index 4a41cd04cc2f37225be60131f6f80658b714f76d..46d19590e98ab48ef053a5d8623701f74adcaefa 100644
--- a/paddle/fluid/framework/executor_cache.cc
+++ b/paddle/fluid/framework/executor_cache.cc
@@ -149,7 +149,7 @@ void AppendSkipDeletionVars(const std::vector<std::string> &append_vars,
 }
 
 std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
-    const ProgramDesc &backward_program) {
+    const ProgramDesc &backward_program, bool skip_no_need_buffer) {
   std::set<std::string> skip_eager_delete_vars;
   auto backward_ops = backward_program.Block(0).AllOps();
   auto &op_info_map = OpInfoMap::Instance();
@@ -158,6 +158,7 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
   std::unordered_set<std::string> no_need_buffer_ins;
   for (size_t i = 0; i < backward_ops.size(); ++i) {
     framework::OpDesc *op = backward_ops[i];
+    VLOG(4) << "parse op type: " << op->Type();
     if (op->Type() == "share_buffer") {
       VLOG(1) << "skip share_buffer op";
       continue;
@@ -166,7 +167,9 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
     auto &op_info = op_info_map.Get(op->Type());
     auto &inferer = op_info.NoNeedBufferVarsInferer();
     no_need_buffer_ins.clear();
-    if (inferer != nullptr) {
+    // TODO(Aurelius84): Remove skip_no_need_buffer after CINN fixes this
+    // problem.
+    if (inferer != nullptr && !skip_no_need_buffer) {
       no_need_buffer_ins =
           inferer(op->Inputs(), op->Outputs(), op->GetAttrMap());
     }
@@ -185,6 +188,7 @@ std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
       }
     }
     for (const std::string &var_name : op_inputs) {
+      VLOG(4) << "parse op.input: " << var_name;
       if (op_outputs.find(var_name) == op_outputs.end()) {
         VLOG(1) << "skip eager var: " << var_name;
         skip_eager_delete_vars.insert(var_name);
diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h
index 196bfd22b1e3d2c4a1ba634514390d2d59c04873..420ccf4ee84c4d701741766536163379d62c81b3 100644
--- a/paddle/fluid/framework/executor_cache.h
+++ b/paddle/fluid/framework/executor_cache.h
@@ -49,8 +49,10 @@ void ParseSafeEagerDeletionSkipVars(
 void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
                             std::set<std::string>* all_vars);
 
+// TODO(Aurelius84): Remove skip_no_need_buffer after CINN fixes this
+// problem.
 std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
-    const ProgramDesc& backward_program);
+    const ProgramDesc& backward_program, bool skip_no_need_buffer = false);
 
 }  // namespace details
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd66156aa5f5ab6c1a1253fc6315c132ad69e36
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class Net(paddle.nn.Layer):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.relu = paddle.nn.functional.relu
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x):
+        y = paddle.full_like(x, 1.0)
+        y.stop_gradient = False
+        z = self.fc(x) * y
+        out = y + z
+        out = self.relu(out)
+
+        return out
+
+
+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class TestCINN(unittest.TestCase):
+    def setUp(self):
+        self.x = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+
+    def train(self, use_cinn):
+        paddle.seed(2022)
+        net = Net()
+        sgd = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=net.parameters()
+        )
+        if use_cinn:
+            net = apply_to_static(net, use_cinn)
+
+        res = []
+        for step in range(10):
+            out = net(self.x)
+            loss = paddle.mean(out)
+            loss.backward()
+            sgd.step()
+            sgd.clear_grad()
+
+            res.append(out.numpy())
+        return res
+
+    def test_cinn(self):
+        dy_res = self.train(use_cinn=False)
+        cinn_res = self.train(use_cinn=True)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_array_equal(cinn_res[i], dy_res[i])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index d9cc9c390dd2f03bd6b52ee420f03303b238eb7d..d2c3d25423a1f97a64dddb211cfd9f428b2164a3 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -409,7 +409,7 @@ class PartialProgramLayer:
         for param in self._params:
             candidate = [
                 var_name
-                for var_name in self.backward_program.block(0).vars.keys()
+                for var_name in self._train_program.block(0).vars.keys()
                 if var_name.endswith(param.name + '@GRAD')
             ]
             if candidate:
@@ -753,7 +753,11 @@ class PartialProgramLayer:
             self._outputs.var_ids
         )
         backward_end_op_index = whole_program.desc.block(0).op_size()
-        backward_skip_vars = self._parse_skip_gc_vars(whole_program)
+        # For the backward process in CINN, all param@GRAD should be skipped
+        # for GC, because they will be shared in scope and used by optimizer.
+        backward_skip_vars = (
+            self._parse_skip_gc_vars(whole_program) + self._param_grad_names
+        )
         backward_builded_program = add_build_strategy_for(
             whole_program,
             backward_start_op_index,
@@ -843,7 +847,7 @@ class PartialProgramLayer:
 
         if backward_program:
             for var_name in core.parse_safe_eager_deletion_skip_vars(
-                backward_program.desc
+                backward_program.desc, True
             ):
                 skip_vars.append(var_name)
         return skip_vars