未验证 提交 7ac748ad 编写于 作者: Z Zeng Jinle 提交者: GitHub

Open gc by default (#18836)

* open gc by default, test=develop

* fix test_train_recognize_digits and disable gc when ngraph is enabled, test=develop

* fix conditional_block op eager deletion bug, test=develop

* add some comments to reviewers, test=develop
上级 3816d221
...@@ -193,7 +193,7 @@ else() ...@@ -193,7 +193,7 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper) target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper conditional_block_op_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
......
...@@ -30,6 +30,7 @@ limitations under the License. */ ...@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer_factory.h" #include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
...@@ -58,10 +59,30 @@ ExecutorPrepareContext::ExecutorPrepareContext( ...@@ -58,10 +59,30 @@ ExecutorPrepareContext::ExecutorPrepareContext(
void ExecutorPrepareContext::PrepareUnusedVars( void ExecutorPrepareContext::PrepareUnusedVars(
const std::vector<std::string>& keep_vars, bool force_disable_gc) { const std::vector<std::string>& keep_vars, bool force_disable_gc) {
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
// FIXME(zjl): There is difference when ngraph and gc are both enabled
// in unittests. I do not know why it happens. Maybe ngraph engine
// would cache some variables?
LOG_FIRST_N(WARNING, 1)
<< "FLAGS_use_ngraph=True, garbage collection strategy is "
"disabled in Executor";
force_disable_gc = true;
}
#endif
force_disable_gc_ = force_disable_gc; force_disable_gc_ = force_disable_gc;
if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) { if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
return; return;
} }
// If gc is enabled and block size > 1
if (prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
block_id_, ops_);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(block_id_, ops_);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
block_id_, ops_);
}
unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars); unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
} }
...@@ -407,13 +428,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -407,13 +428,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
} }
#endif #endif
// If gc is enabled and block size > 1
if (gc && ctx->prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
ctx->ops_);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
ctx->block_id_, ctx->ops_);
}
} }
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
......
...@@ -28,8 +28,15 @@ ...@@ -28,8 +28,15 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Disable gc by default when inference library is built
#ifdef PADDLE_ON_INFERENCE
static const double kDefaultEagerDeleteTensorGB = -1;
#else
static const double kDefaultEagerDeleteTensorGB = 0;
#endif
DEFINE_double( DEFINE_double(
eager_delete_tensor_gb, -1.0, eager_delete_tensor_gb, kDefaultEagerDeleteTensorGB,
"Memory size threshold (GB) when the garbage collector clear tensors." "Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0"); "Disabled when this value is less than 0");
...@@ -48,6 +55,9 @@ GarbageCollector::GarbageCollector(const platform::Place &place, ...@@ -48,6 +55,9 @@ GarbageCollector::GarbageCollector(const platform::Place &place,
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) { : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new GarbageQueue()); garbages_.reset(new GarbageQueue());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
if (max_memory_size_ > 1) {
mutex_.reset(new std::mutex());
}
} }
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place, CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
......
...@@ -46,7 +46,7 @@ class GarbageCollector { ...@@ -46,7 +46,7 @@ class GarbageCollector {
platform::DeviceContext *dev_ctx_; platform::DeviceContext *dev_ctx_;
std::unique_ptr<GarbageQueue> garbages_; std::unique_ptr<GarbageQueue> garbages_;
mutable std::mutex mutex_; mutable std::unique_ptr<std::mutex> mutex_;
const size_t max_memory_size_; const size_t max_memory_size_;
size_t cur_memory_size_{0}; size_t cur_memory_size_{0};
}; };
...@@ -118,7 +118,7 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) { ...@@ -118,7 +118,7 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) {
GarbageQueue *garbage_queue = nullptr; GarbageQueue *garbage_queue = nullptr;
{ {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(*mutex_);
for (auto &obj : objs) { for (auto &obj : objs) {
if (!obj) continue; if (!obj) continue;
cur_memory_size_ += obj->size(); cur_memory_size_ += obj->size();
......
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass) cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
namespace paddle {
namespace framework {
namespace ir {
class ConditionalOpEagerDeletionPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all conditional_op and conditional_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "conditional_block") {
target_ops[compute_op->GetScopeIdx()].first.emplace_back(
compute_op->GetOp());
} else if (compute_op->Name() == "conditional_block_grad") {
target_ops[compute_op->GetScopeIdx()].second.emplace_back(
compute_op->GetOp());
}
}
for (auto &ops_pair : target_ops) {
auto &ifelse_ops = ops_pair.second.first;
auto &ifelse_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
ifelse_ops, ifelse_grad_ops);
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conditional_block_op_eager_deletion_pass,
paddle::framework::ir::ConditionalOpEagerDeletionPass);
...@@ -269,6 +269,11 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -269,6 +269,11 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
} }
} }
auto conditional_block_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get(
"conditional_block_op_eager_deletion_pass");
conditional_block_op_eager_deletion_pass->Apply(graph);
auto while_op_eager_deletion_pass = auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass"); ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
while_op_eager_deletion_pass->Apply(graph); while_op_eager_deletion_pass->Apply(graph);
...@@ -288,5 +293,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) ...@@ -288,5 +293,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kAllPlaces) .RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector); .RequirePassAttr(paddle::framework::ir::kGarbageCollector);
USE_PASS(conditional_block_op_eager_deletion_pass);
USE_PASS(while_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass);
USE_PASS(recurrent_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass);
...@@ -337,6 +337,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -337,6 +337,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
for (auto iter = var_handles.rbegin(); iter != var_handles.rend(); for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) { ++iter) {
if ((*iter)->Node()->IsCtrlVar()) {
break;
}
VLOG(10) << "Try to find last living ops of " << var_name << " " VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time"; << (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure; LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
......
include(operators) include(operators)
register_operators(DEPS naive_executor) register_operators(DEPS naive_executor)
cc_library(op_variant SRCS op_variant.cc DEPS operator proto_desc) cc_library(op_variant SRCS op_variant.cc DEPS operator proto_desc)
cc_library(conditional_block_op_helper SRCS conditional_block_op_helper.cc DEPS operator op_variant conditional_block_op)
cc_library(recurrent_op_helper SRCS recurrent_op_helper.cc DEPS operator op_variant recurrent_op) cc_library(recurrent_op_helper SRCS recurrent_op_helper.cc DEPS operator op_variant recurrent_op)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator op_variant) cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator op_variant)
target_link_libraries(conditional_block_infer_op conditional_block_op)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -17,6 +17,12 @@ limitations under the License. */ ...@@ -17,6 +17,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
const char ConditionalOp::kInputs[] = "Input";
const char ConditionalOp::kOutputs[] = "Out";
const char ConditionalOp::kCondition[] = "Cond";
const char ConditionalOp::kScope[] = "Scope";
const char ConditionalOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
class ConditionalBlockOp : public ConditionalOp { class ConditionalBlockOp : public ConditionalOp {
public: public:
ConditionalBlockOp(const std::string &type, ConditionalBlockOp(const std::string &type,
...@@ -33,20 +39,20 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -33,20 +39,20 @@ class ConditionalBlockOp : public ConditionalOp {
// When is_scalar_condition is True, the conditional variable is a scalar, // When is_scalar_condition is True, the conditional variable is a scalar,
// whether need to execute the operators in sub-block depends on the // whether need to execute the operators in sub-block depends on the
// conditional variable (Cond). // conditional variable (Cond).
auto xs = InputTensors(scope, "Cond"); auto xs = InputTensors(scope, ConditionalOp::kCondition);
need_run = ScalarCondition(xs); need_run = ScalarCondition(xs);
} else { } else {
// When is_scalar_condition is False, the conditional variable maybe a // When is_scalar_condition is False, the conditional variable maybe a
// vector or tensor, whether need to execute the operators in sub-block // vector or tensor, whether need to execute the operators in sub-block
// depends on the input variables (Input). // depends on the input variables (Input).
auto xs = InputTensors(scope, "Input"); auto xs = InputTensors(scope, ConditionalOp::kInputs);
need_run = std::all_of( need_run = std::all_of(
xs.begin(), xs.end(), xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; }); [](const framework::LoDTensor *t) { return t->numel() != 0; });
} }
if (need_run) { if (need_run) {
auto *scope_var = scope.FindVar(Output("Scope")); auto *scope_var = scope.FindVar(Output(ConditionalOp::kScope));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>(); auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
scopes->resize(1); scopes->resize(1);
...@@ -55,7 +61,10 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -55,7 +61,10 @@ class ConditionalBlockOp : public ConditionalOp {
framework::Executor exec(dev_place); framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block"); auto *block = Attr<framework::BlockDesc *>("sub_block");
exec.Run(*block->Program(), &cur_scope, block->ID(), false); auto &skip_vars =
Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
skip_vars);
} }
} }
}; };
...@@ -73,17 +82,17 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -73,17 +82,17 @@ class ConditionalBlockGradOp : public ConditionalOp {
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
bool need_run; bool need_run;
if (Attr<bool>("is_scalar_condition")) { if (Attr<bool>("is_scalar_condition")) {
auto xs = this->InputTensors(scope, "Cond"); auto xs = this->InputTensors(scope, ConditionalOp::kCondition);
need_run = ScalarCondition(xs); need_run = ScalarCondition(xs);
} else { } else {
auto xs = this->InputTensors(scope, "Input"); auto xs = this->InputTensors(scope, ConditionalOp::kInputs);
need_run = std::all_of( need_run = std::all_of(
xs.begin(), xs.end(), xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; }); [](const framework::LoDTensor *t) { return t->numel() != 0; });
} }
if (need_run) { if (need_run) {
auto *scope_var = scope.FindVar(Input("Scope")); auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>(); auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
framework::Scope &cur_scope = *scopes[0]; framework::Scope &cur_scope = *scopes[0];
...@@ -91,10 +100,12 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -91,10 +100,12 @@ class ConditionalBlockGradOp : public ConditionalOp {
framework::Executor exec(dev_place); framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block"); auto *block = Attr<framework::BlockDesc *>("sub_block");
const auto &ins = Inputs("Input"); const auto &ins = Inputs(ConditionalOp::kInputs);
const auto &d_ins = Outputs(framework::GradVarName("Input")); const auto &d_ins =
const auto &conds = Inputs("Cond"); Outputs(framework::GradVarName(ConditionalOp::kInputs));
const auto &d_conds = Outputs(framework::GradVarName("Cond")); const auto &conds = Inputs(ConditionalOp::kCondition);
const auto &d_conds =
Outputs(framework::GradVarName(ConditionalOp::kCondition));
std::vector<std::string> ins_conds_grads; std::vector<std::string> ins_conds_grads;
ins_conds_grads.reserve(ins.size() + conds.size()); ins_conds_grads.reserve(ins.size() + conds.size());
...@@ -142,15 +153,17 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -142,15 +153,17 @@ class ConditionalBlockGradOp : public ConditionalOp {
class ConditionalBlockGradInferShape : public framework::InferShapeBase { class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs("Cond")); PADDLE_ENFORCE(context->HasInputs(ConditionalOp::kCondition));
if (context->HasInputs("Input")) { if (context->HasInputs(ConditionalOp::kInputs)) {
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Input"))); PADDLE_ENFORCE(
context->SetOutputsDim(framework::GradVarName("Input"), context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs)));
context->GetInputsDim("Input")); context->SetOutputsDim(framework::GradVarName(ConditionalOp::kInputs),
context->GetInputsDim(ConditionalOp::kInputs));
} }
if (context->HasOutputs(framework::GradVarName("Cond"))) { if (context->HasOutputs(
context->SetOutputsDim(framework::GradVarName("Cond"), framework::GradVarName(ConditionalOp::kCondition))) {
context->GetInputsDim("Cond")); context->SetOutputsDim(framework::GradVarName(ConditionalOp::kCondition),
context->GetInputsDim(ConditionalOp::kCondition));
} }
} }
}; };
...@@ -163,15 +176,17 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker { ...@@ -163,15 +176,17 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto grad_op = new framework::OpDesc(); auto grad_op = new framework::OpDesc();
grad_op->SetType("conditional_block_grad"); grad_op->SetType("conditional_block_grad");
grad_op->SetInput("Cond", Input("Cond")); grad_op->SetInput(ConditionalOp::kCondition,
grad_op->SetInput("Input", Input("Input")); Input(ConditionalOp::kCondition));
grad_op->SetInput("Out", Output("Out")); grad_op->SetInput(ConditionalOp::kInputs, Input(ConditionalOp::kInputs));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(ConditionalOp::kOutputs, Output(ConditionalOp::kOutputs));
grad_op->SetInput("Scope", Output("Scope")); grad_op->SetInput(framework::GradVarName(ConditionalOp::kOutputs),
grad_op->SetOutput(framework::GradVarName("Cond"), OutputGrad(ConditionalOp::kOutputs));
InputGrad("Cond", false)); grad_op->SetInput(ConditionalOp::kScope, Output(ConditionalOp::kScope));
grad_op->SetOutput(framework::GradVarName("Input"), grad_op->SetOutput(framework::GradVarName(ConditionalOp::kCondition),
InputGrad("Input", false)); InputGrad(ConditionalOp::kCondition, false));
grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs),
InputGrad(ConditionalOp::kInputs, false));
grad_op->SetBlockAttr("sub_block", this->grad_block_[0]); grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition")); grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
return std::unique_ptr<framework::OpDesc>(grad_op); return std::unique_ptr<framework::OpDesc>(grad_op);
......
...@@ -33,6 +33,12 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -33,6 +33,12 @@ class ConditionalOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
static const char kInputs[];
static const char kOutputs[];
static const char kCondition[];
static const char kScope[];
static const char kSkipEagerDeletionVars[];
protected: protected:
std::vector<const framework::LoDTensor *> InputTensors( std::vector<const framework::LoDTensor *> InputTensors(
const framework::Scope &scope, const std::string &in_name) const { const framework::Scope &scope, const std::string &in_name) const {
...@@ -78,13 +84,15 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -78,13 +84,15 @@ class ConditionalOp : public framework::OperatorBase {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Cond", AddInput(ConditionalOp::kCondition,
"The conditional variable of this operator. If Cond is empty, the " "The conditional variable of this operator. If Cond is empty, the "
"whole sub-block will not be executed.") "whole sub-block will not be executed.")
.AsDuplicable(); .AsDuplicable();
AddInput("Input", "The input variables of the sub-block.").AsDuplicable(); AddInput(ConditionalOp::kInputs, "The input variables of the sub-block.")
AddOutput("Out", "The output variables of the sub-block.").AsDuplicable(); .AsDuplicable();
AddOutput("Scope", AddOutput(ConditionalOp::kOutputs, "The output variables of the sub-block.")
.AsDuplicable();
AddOutput(ConditionalOp::kScope,
"(std::vector<Scope*>) The step scope of conditional block. To " "(std::vector<Scope*>) The step scope of conditional block. To "
"unify the conditional block, rnn and while op, the type of " "unify the conditional block, rnn and while op, the type of "
"scope is std::vector<Scope*>"); "scope is std::vector<Scope*>");
...@@ -94,6 +102,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -94,6 +102,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"The conditional variable (Cond) is used as scalar " "The conditional variable (Cond) is used as scalar "
"condition.") "condition.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars,
"Vars that would not be deleted when "
"garbage collection strategy enables")
.SetDefault(std::vector<std::string>());
AddComment(R"DOC(Conditional block operator AddComment(R"DOC(Conditional block operator
If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar, If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/operators/controlflow/op_variant.h"
namespace paddle {
namespace operators {
static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
const OpVariant &fwd_op, const OpVariant &bwd_op) {
return fwd_op.Outputs().at(ConditionalOp::kScope) ==
bwd_op.Inputs().at(ConditionalOp::kScope);
}
static void FindAllConditionalBlockAndConditionalBlockGradOp(
std::vector<OpVariant> *fwd_ops, std::vector<OpVariant> *bwd_ops) {
PADDLE_ENFORCE_GE(fwd_ops->size(), bwd_ops->size());
if (fwd_ops->empty()) return;
const auto *program =
fwd_ops->front().Attr<framework::BlockDesc *>("sub_block")->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "conditional_block") {
fwd_ops->emplace_back(op);
} else if (op->Type() == "conditional_block_grad") {
bwd_ops->emplace_back(op);
}
}
}
PADDLE_ENFORCE_GE(
fwd_ops->size(), bwd_ops->size(),
"There are extra conditional_block_grad ops in the graph or program");
}
static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
OpVariant *bwd_op) {
auto *grad_block = bwd_op->Attr<framework::BlockDesc *>("sub_block");
auto is_skippable_in_fwd = [grad_block](const std::string &var_name) {
return var_name != framework::kEmptyVarName &&
!grad_block->HasVar(var_name);
};
std::unordered_set<std::string> forward_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (is_skippable_in_fwd(in_arg_name)) {
forward_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (is_skippable_in_fwd(out_arg_name)) {
forward_skip_vars.insert(out_arg_name);
}
}
}
auto &fwd_attrs = const_cast<framework::AttributeMap &>(fwd_op->Attrs());
std::vector<std::string> skip_vars_vec(forward_skip_vars.begin(),
forward_skip_vars.end());
VLOG(2) << "Prepare to skip " << skip_vars_vec.size()
<< " var(s): " << string::join_strings(skip_vars_vec, ' ');
fwd_attrs[ConditionalOp::kSkipEagerDeletionVars] = std::move(skip_vars_vec);
}
static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
std::vector<OpVariant> *ifelse_ops,
std::vector<OpVariant> *ifelse_grad_ops) {
FindAllConditionalBlockAndConditionalBlockGradOp(ifelse_ops, ifelse_grad_ops);
VLOG(2) << "Found conditional_block op num: " << ifelse_ops->size()
<< ", conditional_block_grad op num: " << ifelse_grad_ops->size();
if (ifelse_grad_ops->empty()) {
return;
}
std::unordered_set<OpVariant, OpVariant::Hasher> ifelse_op_set(
ifelse_ops->begin(), ifelse_ops->end());
for (auto &bwd_op : *ifelse_grad_ops) {
const OpVariant *matched_fwd_op = nullptr;
for (auto &fwd_op : ifelse_op_set) {
if (IsMatchedConditionalBlockOpAndConditionalBlockGradOp(fwd_op,
bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched conditional_block ops");
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward conditional_block op");
SetSkipVarsForConditionalBlockOp(const_cast<OpVariant *>(matched_fwd_op),
&bwd_op);
ifelse_op_set.erase(*matched_fwd_op);
}
}
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, returns
// This is because all conditional_block_ops and conditional_block_grad_ops
// in the whole program would be processed when block_id is 0 (i.e.
// when Executor::Run() or ParallelExecutor constructs).
// What's more, all conditional_block_ops and conditional_block_grad_ops
// must be processed when block_id is zero. If not, conditional_block_op
// may run first and erase variables used in conditional_block_grad_op,
// and in this moment, conditional_block_grad_ops may be not constructed yet.
if (block_id != 0) return;
std::vector<OpVariant> fwd_ops, bwd_ops;
for (auto &op : all_ops) {
if (op->Type() == "conditional_block") {
fwd_ops.emplace_back(op.get());
} else if (op->Type() == "conditional_block_grad") {
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
&bwd_ops);
}
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
const std::vector<framework::OperatorBase *> &ifelse_ops,
const std::vector<framework::OperatorBase *> &ifelse_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
fwd_ops.reserve(ifelse_ops.size());
for (auto *op : ifelse_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(ifelse_grad_ops.size());
for (auto *op : ifelse_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
&bwd_ops);
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op.h"
namespace paddle {
namespace operators {
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
const std::vector<framework::OperatorBase *> &ifelse_ops,
const std::vector<framework::OperatorBase *> &ifelse_grad_ops);
} // namespace operators
} // namespace paddle
...@@ -167,7 +167,7 @@ void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) { ...@@ -167,7 +167,7 @@ void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
tensor->set_lod(underlying_outs[i].lod()); tensor->set_lod(underlying_outs[i].lod());
} }
// 2. Run the sub-block. // 2. Run the sub-block.
exe_.Run(program_, exe_scope, sub_block_id_, false, true); exe_.Run(program_, exe_scope, sub_block_id_, false, true, {}, true);
// 3. Copy LoDTensors from sink variables to out. // 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size()); out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) { for (size_t i = 0; i < sink_var_names_.size(); ++i) {
......
...@@ -74,7 +74,8 @@ void Train() { ...@@ -74,7 +74,8 @@ void Train() {
float first_loss = 0.0; float first_loss = 0.0;
float last_loss = 0.0; float last_loss = 0.0;
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
executor.Run(*train_program, &scope, 0, false, true); executor.Run(*train_program, &scope, 0, false, true,
{loss_name, "img", "label"});
if (i == 0) { if (i == 0) {
first_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0]; first_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0];
} else if (i == 99) { } else if (i == 99) {
......
...@@ -112,6 +112,7 @@ class TestSendOp(unittest.TestCase): ...@@ -112,6 +112,7 @@ class TestSendOp(unittest.TestCase):
dtype='float32', dtype='float32',
name='X', name='X',
append_batch_size=False) append_batch_size=False)
x.persistable = True
fluid.initializer.Constant(value=2.3)(x, main.global_block()) fluid.initializer.Constant(value=2.3)(x, main.global_block())
get_var = main.global_block().create_var( get_var = main.global_block().create_var(
...@@ -121,6 +122,13 @@ class TestSendOp(unittest.TestCase): ...@@ -121,6 +122,13 @@ class TestSendOp(unittest.TestCase):
shape=[32, 32]) shape=[32, 32])
fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
# NOTE(zjl): `Send` is async send, which means that the sent
# variable would be needed even though `Send` op runs.
# Is it a right design? If I do not set `x.persistable = True`,
# this unittest would hang in rpc client after x is deleted.
#
# BTW, `Send` is not a public API to users. So I set
# `x.persistable = True` to be a hot fix of this unittest.
Send("127.0.0.1:%d" % port, [x]) Send("127.0.0.1:%d" % port, [x])
o = Recv("127.0.0.1:%d" % port, [get_var]) o = Recv("127.0.0.1:%d" % port, [get_var])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册