提交 e694d0c2 编写于 作者: S sneaxiy

fix while_op eager deletion bug

add unittest
test=develop
上级 35a25784
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -101,7 +101,7 @@ static void DeleteUnusedTensors( ...@@ -101,7 +101,7 @@ static void DeleteUnusedTensors(
if (--(it->second) == 0) { if (--(it->second) == 0) {
auto* var = scope.FindVar(name); auto* var = scope.FindVar(name);
if (var != nullptr) { if (var != nullptr) {
VLOG(10) << "Erase tensor \'" << name << "\'"; VLOG(2) << "Erase tensor \'" << name << "\'";
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
erase_tensors.insert(var->GetMutable<LoDTensor>()); erase_tensors.insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
......
...@@ -32,6 +32,20 @@ static constexpr char kStepScopes[] = "StepScopes"; ...@@ -32,6 +32,20 @@ static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X"; static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD"; static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out"; static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
namespace { // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
const std::vector<std::string> &vars) {
std::string str = "Skip " + std::to_string(vars.size()) +
" var(s) in eager deletion mode: ";
for (auto &var : vars) {
str.append(var);
str.push_back(' ');
}
return str;
}
} // NOLINT
class WhileOp : public framework::OperatorBase { class WhileOp : public framework::OperatorBase {
public: public:
...@@ -59,21 +73,12 @@ class WhileOp : public framework::OperatorBase { ...@@ -59,21 +73,12 @@ class WhileOp : public framework::OperatorBase {
"Condition of while op must in CPU memory."); "Condition of while op must in CPU memory.");
bool is_test = Attr<bool>("is_test"); bool is_test = Attr<bool>("is_test");
auto &skip_eager_deletion_vars = auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
Attr<std::vector<std::string>>("skip_eager_deletion_vars"); if (framework::GetEagerDeletionThreshold() >= 0) {
if (framework::GetEagerDeletionThreshold() >= 0 && VLOG_IS_ON(10)) { VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
std::string debug_string =
"Skip " + std::to_string(skip_eager_deletion_vars.size()) +
" vars in eager deletion mode: ";
for (auto &var : skip_eager_deletion_vars) {
debug_string.append(var);
debug_string.push_back(' ');
}
VLOG(10) << debug_string;
} }
auto ctx = auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
executor.Prepare(*program, block->ID(), skip_eager_deletion_vars);
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
...@@ -110,7 +115,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,7 +115,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<std::string>>("skip_eager_deletion_vars", AddAttr<std::vector<std::string>>(kSkipEagerDeletionVars,
"Vars that would skip eager deletion." "Vars that would skip eager deletion."
"Users should not set this manually.") "Users should not set this manually.")
.SetDefault(std::vector<std::string>()); .SetDefault(std::vector<std::string>());
...@@ -137,7 +142,12 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -137,7 +142,12 @@ class WhileGradOp : public framework::OperatorBase {
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto ctx = executor.Prepare(*program, block->ID());
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
if (framework::GetEagerDeletionThreshold() >= 0) {
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
}
auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
auto *step_scopes = auto *step_scopes =
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
...@@ -359,29 +369,51 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -359,29 +369,51 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed. // while operator could be renamed.
while_grad->SetAttr("original_output_grad", output_grads_list); while_grad->SetAttr("original_output_grad", output_grads_list);
/* The following codes are used in eager deletion mode */ /* The followi_ng codes are used in eager deletion mode */
std::unordered_set<std::string> bwd_skip_vars;
if (framework::GetEagerDeletionThreshold() >= 0) { if (framework::GetEagerDeletionThreshold() >= 0) {
std::unordered_set<std::string> skip_vars; std::unordered_set<std::string> fwd_skip_vars;
for (auto *op_desc : grad_block->AllOps()) { for (auto *op_desc : grad_block->AllOps()) {
auto skippable = [&](const std::string &name) {
return !grad_block->HasVar(name) &&
(fwd_block->HasVarRecursive(name) ||
parent_block->HasVarRecursive(name));
};
for (auto &in_arg_name : op_desc->InputArgumentNames()) { for (auto &in_arg_name : op_desc->InputArgumentNames()) {
// If input var of ops inside grad_block is not from grad_block, if (skippable(in_arg_name)) {
// it cannot be deleted when forward while_op runs fwd_skip_vars.insert(in_arg_name);
if (in_arg_name != framework::kEmptyVarName && }
!grad_block->HasVar(in_arg_name)) { }
skip_vars.insert(in_arg_name);
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (skippable(out_arg_name)) {
fwd_skip_vars.insert(out_arg_name);
} }
} }
} }
if (!skip_vars.empty()) { if (!fwd_skip_vars.empty()) {
// FIXME(zjl): ugly const_cast here, maybe we should find a better way // FIXME(zjl): ugly const_cast here, maybe we should find a better way
// to modify forward while_op // to modify forward while_op
auto &fwd_while_op = const_cast<framework::OpDesc &>(ForwardOp()); auto &fwd_while_op = const_cast<framework::OpDesc &>(ForwardOp());
fwd_while_op.SetAttr( fwd_while_op.SetAttr(kSkipEagerDeletionVars,
"skip_eager_deletion_vars", std::vector<std::string>(fwd_skip_vars.begin(),
std::vector<std::string>(skip_vars.begin(), skip_vars.end())); fwd_skip_vars.end()));
}
// Find backward skip vars
auto fwd_input = Input(kX);
for (size_t i = 0; i < igs.size(); ++i) {
if (igs[i] == framework::kEmptyVarName) {
continue;
}
bwd_skip_vars.insert(igs[i]);
bwd_skip_vars.insert(framework::GradVarName(fwd_input[i]));
} }
} }
while_grad->SetAttr(
kSkipEagerDeletionVars,
std::vector<std::string>(bwd_skip_vars.begin(), bwd_skip_vars.end()));
return std::unique_ptr<framework::OpDesc>(while_grad); return std::unique_ptr<framework::OpDesc>(while_grad);
} }
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
from test_parallel_executor_mnist import TestMNIST
class EagerDeletionTestMNIST(TestMNIST):
pass
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
from test_parallel_executor_seresnext import TestResnet
class EagerDeletionTestSEResNext(TestResnet):
pass
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
from test_parallel_executor_transformer import TestTransformer
class EagerDeletionTestTransformer(TestTransformer):
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册