From e694d0c2e487a854103e0cc4796f92af6d27ccfd Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 4 Dec 2018 09:45:50 +0000 Subject: [PATCH] fix while_op eager deletion bug add unittest test=develop --- .../details/eager_deletion_op_handle.cc | 2 + paddle/fluid/framework/executor.cc | 2 +- .../fluid/operators/controlflow/while_op.cc | 84 +++++++++++++------ .../unittests/test_eager_deletion_mnist.py | 27 ++++++ .../test_eager_deletion_seresnext.py | 27 ++++++ .../test_eager_deletion_transformer.py | 27 ++++++ 6 files changed, 142 insertions(+), 27 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py create mode 100644 python/paddle/fluid/tests/unittests/test_eager_deletion_seresnext.py create mode 100644 python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index 41f616035d7..54715fed8d9 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -16,7 +16,9 @@ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" +#endif namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 5823f33034a..f443c2d8cf6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -101,7 +101,7 @@ static void DeleteUnusedTensors( if (--(it->second) == 0) { auto* var = scope.FindVar(name); if (var != nullptr) { - VLOG(10) << "Erase tensor \'" << name << "\'"; + VLOG(2) << "Erase tensor \'" << name << "\'"; if (var->IsType()) { erase_tensors.insert(var->GetMutable()); } else if (var->IsType()) { diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index d8410b40586..da7cad82d8d 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -32,6 +32,20 @@ static constexpr char kStepScopes[] = "StepScopes"; static constexpr char kX[] = "X"; static constexpr char kXGRAD[] = "X@GRAD"; static constexpr char kOutputs[] = "Out"; +static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; + +namespace { // NOLINT +static std::string GetSkipEagerDeletionVarsDebugString( + const std::vector &vars) { + std::string str = "Skip " + std::to_string(vars.size()) + + " var(s) in eager deletion mode: "; + for (auto &var : vars) { + str.append(var); + str.push_back(' '); + } + return str; +} +} // NOLINT class WhileOp : public framework::OperatorBase { public: @@ -59,21 +73,12 @@ class WhileOp : public framework::OperatorBase { "Condition of while op must in CPU memory."); bool is_test = Attr("is_test"); - auto &skip_eager_deletion_vars = - Attr>("skip_eager_deletion_vars"); - if (framework::GetEagerDeletionThreshold() >= 0 && VLOG_IS_ON(10)) { - std::string debug_string = - "Skip " + std::to_string(skip_eager_deletion_vars.size()) + - " vars in eager deletion mode: "; - for (auto &var : skip_eager_deletion_vars) { - debug_string.append(var); - debug_string.push_back(' '); - } - VLOG(10) << debug_string; + auto &skip_vars = Attr>(kSkipEagerDeletionVars); + if (framework::GetEagerDeletionThreshold() >= 0) { + VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); } - auto ctx = - executor.Prepare(*program, block->ID(), skip_eager_deletion_vars); + auto ctx = executor.Prepare(*program, block->ID(), skip_vars); while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); @@ -110,7 +115,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddAttr>("skip_eager_deletion_vars", + AddAttr>(kSkipEagerDeletionVars, "Vars that would skip eager deletion." "Users should not set this manually.") .SetDefault(std::vector()); @@ -137,7 +142,12 @@ class WhileGradOp : public framework::OperatorBase { framework::Executor executor(dev_place); auto *block = Attr(kStepBlock); auto *program = block->Program(); - auto ctx = executor.Prepare(*program, block->ID()); + + auto &skip_vars = Attr>(kSkipEagerDeletionVars); + if (framework::GetEagerDeletionThreshold() >= 0) { + VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); + } + auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto *step_scopes = scope.FindVar(Input(kStepScopes))->GetMutable(); @@ -359,29 +369,51 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { // while operator could be renamed. while_grad->SetAttr("original_output_grad", output_grads_list); - /* The following codes are used in eager deletion mode */ + /* The followi_ng codes are used in eager deletion mode */ + std::unordered_set bwd_skip_vars; if (framework::GetEagerDeletionThreshold() >= 0) { - std::unordered_set skip_vars; + std::unordered_set fwd_skip_vars; for (auto *op_desc : grad_block->AllOps()) { + auto skippable = [&](const std::string &name) { + return !grad_block->HasVar(name) && + (fwd_block->HasVarRecursive(name) || + parent_block->HasVarRecursive(name)); + }; for (auto &in_arg_name : op_desc->InputArgumentNames()) { - // If input var of ops inside grad_block is not from grad_block, - // it cannot be deleted when forward while_op runs - if (in_arg_name != framework::kEmptyVarName && - !grad_block->HasVar(in_arg_name)) { - skip_vars.insert(in_arg_name); + if (skippable(in_arg_name)) { + fwd_skip_vars.insert(in_arg_name); + } + } + + for (auto &out_arg_name : op_desc->OutputArgumentNames()) { + if (skippable(out_arg_name)) { + fwd_skip_vars.insert(out_arg_name); } } } - if (!skip_vars.empty()) { + if (!fwd_skip_vars.empty()) { // FIXME(zjl): ugly const_cast here, maybe we should find a better way // to modify forward while_op auto &fwd_while_op = const_cast(ForwardOp()); - fwd_while_op.SetAttr( - "skip_eager_deletion_vars", - std::vector(skip_vars.begin(), skip_vars.end())); + fwd_while_op.SetAttr(kSkipEagerDeletionVars, + std::vector(fwd_skip_vars.begin(), + fwd_skip_vars.end())); + } + + // Find backward skip vars + auto fwd_input = Input(kX); + for (size_t i = 0; i < igs.size(); ++i) { + if (igs[i] == framework::kEmptyVarName) { + continue; + } + bwd_skip_vars.insert(igs[i]); + bwd_skip_vars.insert(framework::GradVarName(fwd_input[i])); } } + while_grad->SetAttr( + kSkipEagerDeletionVars, + std::vector(bwd_skip_vars.begin(), bwd_skip_vars.end())); return std::unique_ptr(while_grad); } diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py new file mode 100644 index 00000000000..7ec1f0ae753 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" + +from test_parallel_executor_mnist import TestMNIST + + +class EagerDeletionTestMNIST(TestMNIST): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_seresnext.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_seresnext.py new file mode 100644 index 00000000000..2dcdbdb8f13 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_seresnext.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" + +from test_parallel_executor_seresnext import TestResnet + + +class EagerDeletionTestSEResNext(TestResnet): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py new file mode 100644 index 00000000000..754d5fd4095 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" + +from test_parallel_executor_transformer import TestTransformer + + +class EagerDeletionTestTransformer(TestTransformer): + pass + + +if __name__ == '__main__': + unittest.main() -- GitLab