未验证 提交 472f16b5 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #16063 from sneaxiy/enhance_gc

Enhance gc
...@@ -174,7 +174,7 @@ else() ...@@ -174,7 +174,7 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
target_link_libraries(executor garbage_collector) target_link_libraries(executor garbage_collector while_op_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
......
...@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_ ...@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase { ...@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase {
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t scope_idx); size_t scope_idx);
OperatorBase *GetOp() { return op_.get(); }
std::string Name() const override; std::string Name() const override;
const Scope *GetScope() const { return scope_; } const Scope *GetScope() const { return scope_; }
......
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle( ...@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
} }
} }
#endif #endif
PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
} }
EagerDeletionOpHandle::~EagerDeletionOpHandle() { EagerDeletionOpHandle::~EagerDeletionOpHandle() {
...@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() { ...@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() { void EagerDeletionOpHandle::RunImpl() {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages; std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) { for (auto &name : var_names_) {
auto it = ref_cnts_->find(name); auto it = ref_cnts_->find(name);
// Var not found, not reference count has not decreased to 0 // Reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) { if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
continue; continue;
} }
if (!exec_scope) {
exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
// Var not found
auto *var = exec_scope->FindVar(name); auto *var = exec_scope->FindVar(name);
if (var == nullptr) { if (var == nullptr) {
continue; continue;
......
...@@ -12,20 +12,173 @@ ...@@ -12,20 +12,173 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm>
#include <functional>
#include <queue> #include <queue>
#include <string> #include <string>
#include <tuple>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
// op -> variables which can be deleted after op runs
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
// Check whether the variable is LoDTensor based on static VarDesc info
static bool IsLoDTensor(VarDesc *var) {
return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
}
// Get memory size of LoDTensor
static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc);
PADDLE_ENFORCE(IsLoDTensor(var_desc));
auto dims = var_desc->GetShape();
return SizeOfType(var_desc->GetDataType()) *
std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>());
}
// Split all variables in the graph into LoDTensor and Non-LoDTensor (e.g.
// SelectedRows, LoDTensorArray)
// Since partial GC is based on static analysis of memory size of each variable
// So we should skip SelectedRows and LoDTensorArray here
static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear();
other_vars->clear();
for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) {
(*lod_tensors)[op_vars_pair.first].insert(var_name);
} else {
(*other_vars)[op_vars_pair.first].insert(var_name);
}
}
}
}
struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx)
: name_(name),
memory_size_(memory_size),
op_(op),
scope_idx_(scope_idx) {}
std::string name_; // variable name
int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); }
};
// Delete delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size = 0
if (fraction_of_memory_size <= 0.0) return {};
/**
* Step 1: Split all variables into LoDTensor and Non-LoDTensor.
* We can only calculate memory size of LoDTensors
*/
OpToVarNameSetMap lod_tensors, other_vars;
SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
// Perform complete gc when fraction_of_memory_size >= 1
if (fraction_of_memory_size >= 1.0) {
return delete_lod_tensor_only ? lod_tensors : m;
}
/**
* Step 2: build GCVarInfos, and calculate total memory sizes of each device
*/
// place -> variable info (name, memory size, place, scope_idx)
std::map<platform::Place, std::vector<GCVarInfo>> place_to_vars;
// place -> total memory sizes
std::map<platform::Place, int64_t> place_to_size;
for (auto &op_vars_pair : lod_tensors) {
auto *op = op_vars_pair.first;
auto &var_names = op_vars_pair.second;
auto scope_idx = op->GetScopeIdx();
auto &place = places[scope_idx];
for (auto &var_name : var_names) {
auto var_size = GetMemorySize(vars[scope_idx], var_name);
GCVarInfo var_info(var_name, var_size, op, scope_idx);
place_to_size[place] += var_info.AbsMemorySize();
place_to_vars[place].emplace_back(std::move(var_info));
}
}
/**
* Step 3: sort GCVarInfos, and only delete the largest variables.
*/
OpToVarNameSetMap partial_vars;
for (auto &place_to_var_pair : place_to_vars) {
auto &place = place_to_var_pair.first;
auto &gc_vars = place_to_var_pair.second;
std::sort(gc_vars.begin(), gc_vars.end(),
[](const GCVarInfo &var1, const GCVarInfo &var2) {
return var1.AbsMemorySize() > var2.AbsMemorySize();
});
int64_t accumulated_size = 0;
int64_t size_threshold =
static_cast<int64_t>(fraction_of_memory_size * place_to_size[place]);
for (size_t i = 0; i < gc_vars.size() && accumulated_size < size_threshold;
++i) {
partial_vars[gc_vars[i].op_].insert(gc_vars[i].name_);
accumulated_size += gc_vars[i].AbsMemorySize();
}
}
/**
* Step 4: Combine other vars (SelectedRows, LoDTensorArray)
*/
if (!delete_lod_tensor_only) {
for (auto &op_vars_pair : other_vars) {
partial_vars[op_vars_pair.first].insert(op_vars_pair.second.begin(),
op_vars_pair.second.end());
}
}
return partial_vars;
}
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = auto &ref_cnts =
...@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
// a reverse map of last_live_ops // a reverse map of last_live_ops
// i.e., last op --> variable names which can be deleted. // i.e., last op --> variable names which can be deleted.
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>> OpToVarNameSetMap op_vars_map;
op_vars_map;
for (auto &var_ops_map : last_live_ops) { for (auto &var_ops_map : last_live_ops) {
for (auto &var_ops_pair : var_ops_map) { for (auto &var_ops_pair : var_ops_map) {
const std::string &var_name = var_ops_pair.first; const std::string &var_name = var_ops_pair.first;
...@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
} }
} }
op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
FLAGS_memory_fraction_of_eager_deletion);
for (auto &pair : op_vars_map) { for (auto &pair : op_vars_map) {
auto *op = pair.first; auto *op = pair.first;
auto &var_names = pair.second; auto &var_names = pair.second;
...@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
eager_deletion_op->AddOutput(dummy_leaf); eager_deletion_op->AddOutput(dummy_leaf);
} }
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
<< FLAGS_memory_fraction_of_eager_deletion;
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)"; VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
return graph;
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
} }
} // namespace details } // namespace details
...@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass, ...@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces) .RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector); .RequirePassAttr(paddle::framework::details::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass);
...@@ -12,9 +12,13 @@ ...@@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include <queue> #include <queue>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
...@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, ...@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
return shrink_func(computation_op); return shrink_func(computation_op);
} }
static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount); auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
......
...@@ -13,9 +13,22 @@ ...@@ -13,9 +13,22 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details {} // namespace details namespace details {
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
} // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <atomic> #include <atomic>
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -25,6 +26,10 @@ ...@@ -25,6 +26,10 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class VarDesc;
class VarHandle;
namespace details { namespace details {
class ComputationOpHandle; class ComputationOpHandle;
...@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector"; ...@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places"; const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars = using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>; std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var"; const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace paddle {
namespace framework {
namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
target_ops[compute_op->GetScopeIdx()].first.emplace_back(
compute_op->GetOp());
} else if (compute_op->Name() == "while_grad") {
target_ops[compute_op->GetScopeIdx()].second.emplace_back(
compute_op->GetOp());
}
}
for (auto &ops_pair : target_ops) {
auto &while_ops = ops_pair.second.first;
auto &while_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
...@@ -14,6 +14,10 @@ limitations under the License. */ ...@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include <deque> #include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
...@@ -23,6 +27,7 @@ limitations under the License. */ ...@@ -23,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts( ...@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
ExecutorPrepareContext::ExecutorPrepareContext( ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id, const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars) const std::vector<std::string>& keep_vars, bool force_disable_gc)
: prog_(prog), block_id_(block_id) { : prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
if (GetEagerDeletionThreshold() >= 0) { if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) {
global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), global_ref_cnts_ =
skip_ref_cnt_vars); GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars);
} }
} }
...@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
} }
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars,
bool force_disable_gc) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc); if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc); if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif #endif
auto ctx = Prepare(pdesc, block_id); auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
} }
...@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars) { const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
std::unique_ptr<ExecutorPrepareContext> ctx( std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext(
new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars)); program, block_id, skip_ref_cnt_vars, force_disable_gc));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
...@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( ...@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids, const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) { const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
bool force_disable_gc) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(), skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equals to block number %d", "skip_ref_cnt_vars should be either empty or equals to block number %d",
...@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
for (auto& bid : block_ids) { for (auto& bid : block_ids) {
ExecutorPrepareContext* ctx; ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) { if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid); ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
force_disable_gc);
} else { } else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]); ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
force_disable_gc);
} }
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid); auto& block = program.Block(bid);
...@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc; std::unique_ptr<GarbageCollector> gc;
// skip while_op and while_grad_op temporarily // FIXME(zjl): recurrent_op is rather complex, we would
if (max_memory_size >= 0 && !keep_kids) { // disable gc forcely in recurrent_op
if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
ctx->ResetReferenceCount(); ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
} }
#endif #endif
// If gc is enabled and block size > 1
if (gc && ctx->prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
ctx->ops_);
}
} }
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
......
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -30,7 +32,8 @@ namespace framework { ...@@ -30,7 +32,8 @@ namespace framework {
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id, ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars = const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>()); std::vector<std::string>(),
bool force_disable_gc = false);
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
...@@ -38,6 +41,7 @@ struct ExecutorPrepareContext { ...@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; size_t block_id_;
bool force_disable_gc_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, size_t> global_ref_cnts_; std::unordered_map<std::string, size_t> global_ref_cnts_;
...@@ -66,7 +70,10 @@ class Executor { ...@@ -66,7 +70,10 @@ class Executor {
* Scope * Scope
*/ */
void Run(const ProgramDesc& prog, Scope* scope, int block_id, void Run(const ProgramDesc& prog, Scope* scope, int block_id,
bool create_local_scope = true, bool create_vars = true); bool create_local_scope = true, bool create_vars = true,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
// This API is very slow. // This API is very slow.
void Run(const ProgramDesc& program, Scope* scope, void Run(const ProgramDesc& program, Scope* scope,
...@@ -79,12 +86,14 @@ class Executor { ...@@ -79,12 +86,14 @@ class Executor {
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars = const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>()); std::vector<std::string>(),
bool force_disable_gc = false);
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare( static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids, const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars = const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
std::vector<std::vector<std::string>>()); std::vector<std::vector<std::string>>(),
bool force_disable_gc = false);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
......
include(operators) include(operators)
register_operators(DEPS naive_executor) register_operators(DEPS naive_executor)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
...@@ -26,14 +27,6 @@ namespace operators { ...@@ -26,14 +27,6 @@ namespace operators {
using StepScopeVar = std::vector<framework::Scope *>; using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
static constexpr char kStepBlock[] = "sub_block";
static constexpr char kCondition[] = "Condition";
static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
namespace { // NOLINT namespace { // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString( static std::string GetSkipEagerDeletionVarsDebugString(
const std::vector<std::string> &vars) { const std::vector<std::string> &vars) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace operators {
// OpVariant is a wrapper class of OpDesc and OperatorBase
// So that API would be the same.
class OpVariant {
struct InputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Inputs());
}
};
struct OutputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Outputs());
}
};
struct AttributeMapVisitor
: public boost::static_visitor<const framework::AttributeMap *> {
const framework::AttributeMap *operator()(
const framework::OpDesc *op) const {
return &(op->GetAttrMap());
}
const framework::AttributeMap *operator()(
const framework::OperatorBase *op) const {
return &(op->Attrs());
}
};
struct RawPointerVisitor : public boost::static_visitor<const void *> {
template <typename OpType>
const void *operator()(const OpType *op) const {
return op;
}
};
public:
OpVariant(const framework::OperatorBase *op) : op_(op) {} // NOLINT
OpVariant(const framework::OpDesc *op) : op_(op) {} // NOLINT
const framework::VariableNameMap &Inputs() const {
return *boost::apply_visitor(InputsVisitor(), op_);
}
const framework::VariableNameMap &Outputs() const {
return *boost::apply_visitor(OutputsVisitor(), op_);
}
const framework::AttributeMap &Attrs() const {
return *boost::apply_visitor(AttributeMapVisitor(), op_);
}
template <typename AttrType>
const AttrType &Attr(const std::string &name) const {
auto &attrs = Attrs();
auto it = attrs.find(name);
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
return boost::get<AttrType>(it->second);
}
bool operator==(const OpVariant &other) const {
return RawPointer() == other.RawPointer();
}
const void *RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_);
}
int which() const { return static_cast<int>(op_.which()); }
struct Hasher {
size_t operator()(const OpVariant &op) const {
return reinterpret_cast<size_t>(op.RawPointer());
}
};
private:
const boost::variant<const framework::OperatorBase *,
const framework::OpDesc *>
op_;
};
static std::string GetDebugString(const std::vector<std::string> &names) {
if (names.empty()) return "";
std::string ret = names[0];
for (size_t i = 1; i < names.size(); ++i) {
ret += (" " + names[i]);
}
return ret;
}
// Set skip variables of while_op and while_grad_op
// These variables should be skipped when eager deletion enables.
// It is because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
VLOG(2) << "Prepare to skip " << attr.size()
<< " var(s): " << GetDebugString(attr);
attrs[kSkipEagerDeletionVars] = std::move(attr);
}
// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
const OpVariant &grad_op) {
return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}
// Test whether the variable is skippable in forward while_op
// The variable is skippable in while_op when the variable used in while_grad
// is not from grad_block.
static bool IsSkippableVar(const std::string &name,
framework::BlockDesc *grad_block) {
return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}
static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
const OpVariant &bwd_op) {
auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
// Find all skippable variables in forward while_op
std::unordered_set<std::string> forward_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (IsSkippableVar(in_arg_name, grad_block)) {
forward_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (IsSkippableVar(out_arg_name, grad_block)) {
forward_skip_vars.insert(out_arg_name);
}
}
}
SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
forward_skip_vars.end()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto &fwd_input = fwd_op.Inputs().at(kX);
auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
PADDLE_ENFORCE_EQ(
fwd_input.size(), in_grads.size(),
"Backward input gradient number does not match forward input number.");
std::unordered_set<std::string> backward_skip_vars;
for (size_t i = 0; i < in_grads.size(); ++i) {
if (in_grads[i] == framework::kEmptyVarName) {
continue;
}
backward_skip_vars.insert(in_grads[i]);
backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
}
SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
backward_skip_vars.end()));
}
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
if (while_ops->empty()) return;
const auto *program =
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "while") {
while_ops->emplace_back(op);
} else if (op->Type() == "while_grad") {
while_grad_ops->emplace_back(op);
}
}
}
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
"There are extra while_grad ops in the graph or program");
}
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
VLOG(2) << "Found while op num: " << while_ops->size()
<< ", while grad op num: " << while_grad_ops->size();
if (while_grad_ops->empty()) {
return;
}
std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
while_ops->begin(), while_ops->end());
for (auto &bwd_op : *while_grad_ops) {
const OpVariant *matched_fwd_op = nullptr;
for (auto &fwd_op : while_op_set) {
if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched while ops");
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward while op.");
ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
while_op_set.erase(*matched_fwd_op);
}
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
// would be processed when block_id is 0 (i.e. when Executor::Run() or
// ParallelExecutor constructs).
// What's more, all while_ops and while_grad_ops must be processed when
// block_id is zero. If not, while_op may run first and erase variables
// used in while_grad_op, and in this moment, while_grad_ops may be not
// constructed yet.
if (block_id != 0) return;
std::vector<OpVariant> fwd_ops, bwd_ops;
for (auto &op : all_ops) {
if (op->Type() == "while") {
fwd_ops.emplace_back(op.get());
} else if (op->Type() == "while_grad") {
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
fwd_ops.reserve(while_ops.size());
for (auto *op : while_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(while_grad_ops.size());
for (auto *op : while_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,19 +14,30 @@ ...@@ -14,19 +14,30 @@
#pragma once #pragma once
#include "paddle/fluid/framework/ir/graph.h" #include <memory>
#include "paddle/fluid/framework/ir/pass.h" #include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace operators {
namespace details {
class EagerDeletionPass : public ir::Pass { static constexpr char kStepBlock[] = "sub_block";
protected: static constexpr char kCondition[] = "Condition";
std::unique_ptr<ir::Graph> ApplyImpl( static constexpr char kStepScopes[] = "StepScopes";
std::unique_ptr<ir::Graph> graph) const override; static constexpr char kX[] = "X";
}; static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
} // namespace details void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
} // namespace framework int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops);
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase { ...@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase {
// Every inputs are linked now, execute! // Every inputs are linked now, execute!
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/); false /*create_local_scope*/, true /*create_vars*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
...@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase {
VLOG(5) << "Recurrent memory linking finished "; VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope // Run step block with cur_scope
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/); false /*create_local_scope*/, true /*create_vars*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
VLOG(5) << "executor.Run finished "; VLOG(5) << "executor.Run finished ";
......
...@@ -876,9 +876,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -876,9 +876,11 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
.def("close", &Executor::Close) .def("close", &Executor::Close)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) { int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) {
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
self.Run(prog, scope, block_id, create_local_scope, create_vars); self.Run(prog, scope, block_id, create_local_scope, create_vars,
fetch_vars);
}); });
m.def("init_gflags", framework::InitGflags); m.def("init_gflags", framework::InitGflags);
......
...@@ -128,11 +128,11 @@ def __bootstrap__(): ...@@ -128,11 +128,11 @@ def __bootstrap__():
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph', 'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb', 'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
'fast_eager_deletion_mode', 'allocator_strategy', 'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
'reader_queue_speed_test_mode', 'print_sub_graph_dir', 'allocator_strategy', 'reader_queue_speed_test_mode',
'pe_profile_fname', 'warpctc_dir', 'inner_op_parallelism', 'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
'enable_parallel_graph', 'multiple_of_cupti_buffer_size', 'inner_op_parallelism', 'enable_parallel_graph',
'enable_subgraph_optimize' 'multiple_of_cupti_buffer_size', 'enable_subgraph_optimize'
] ]
if 'Darwin' not in sysstr: if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory') read_env_flags.append('use_pinned_memory')
......
...@@ -590,7 +590,7 @@ class Executor(object): ...@@ -590,7 +590,7 @@ class Executor(object):
fetch_var_name=fetch_var_name) fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope) self._feed_data(program, feed, feed_var_name, scope)
exe.run(program.desc, scope, 0, True, True) exe.run(program.desc, scope, 0, True, True, fetch_var_name)
outs = self._fetch_data(fetch_list, fetch_var_name, scope) outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy: if return_numpy:
outs = as_numpy(outs) outs = as_numpy(outs)
......
...@@ -16,8 +16,7 @@ import os ...@@ -16,8 +16,7 @@ import os
import unittest import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ[ os.environ['RECORDIO_FILENAME'] = './eager_deletion_transformer.wmt16.recordio'
'RECORDIO_FILENAME'] = '/tmp/eager_deletion_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer from test_parallel_executor_transformer import TestTransformer
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['CPU_NUM'] = '2'
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['FLAGS_fast_eager_deletion_mode'] = '1'
import unittest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
from paddle.fluid.backward import append_backward
import paddle.fluid.compiler as compiler
import numpy
import multiprocessing
class TestEagerDeletionWhileOpBase(unittest.TestCase):
def test_main(self):
places = [core.CPUPlace(), ]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for p in places:
for with_data_parallel in [False, True]:
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.run_main(p, with_data_parallel)
def run_main(self, place, with_data_parallel):
self.place = place
self.with_data_parallel = with_data_parallel
if not core.is_compiled_with_cuda() and isinstance(self.place,
core.CUDAPlace):
return
if isinstance(self.place, core.CUDAPlace):
device_cnt = core.get_cuda_device_count(
) if self.with_data_parallel else 1
else:
device_cnt = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count(
))) if self.with_data_parallel else 1
d0 = layers.data(
"d0", shape=[10], append_batch_size=False, dtype='float32')
d1 = layers.data(
"d1", shape=[10], append_batch_size=False, dtype='float32')
d2 = layers.data(
"d2", shape=[10], append_batch_size=False, dtype='float32')
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i)
i = layers.increment(i)
layers.array_write(d1, i, array=data_array)
i = layers.increment(i)
layers.array_write(d2, i, array=data_array)
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len)
j = layers.fill_constant(shape=[1], dtype='int64', value=1)
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = layers.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
d = layers.reshape(d, shape=[10])
prev = layers.reshape(prev, shape=[10])
result = layers.sums(input=[d, prev])
i = layers.increment(x=i, in_place=True)
layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond)
with while_op2.block():
d2 = layers.array_read(array=data_array, i=j)
prev2 = layers.array_read(array=mem_array, i=j)
d2 = layers.reshape(d2, shape=[10])
prev2 = layers.reshape(prev2, shape=[10])
result2 = layers.sums(input=[d2, prev2])
j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
sum_result.persistable = True
tmp = layers.unsqueeze(sum_result, axes=[0])
tmp = layers.expand(tmp, expand_times=[10, 1])
fc = layers.fc(tmp, size=256)
loss = layers.mean(sum_result)
optim = fluid.optimizer.Adam(learning_rate=1e-3)
optim.minimize(loss)
exe = Executor(self.place)
exe.run(fluid.default_startup_program())
prog = compiler.CompiledProgram(fluid.default_main_program())
if self.with_data_parallel:
prog = prog.with_data_parallel()
for _ in range(5):
d = []
for i in range(3):
tmp = numpy.random.random(size=[10]).astype('float32')
if not self.with_data_parallel:
d.append(tmp)
else:
d.append(numpy.array([tmp] * device_cnt))
outs = exe.run(program=prog,
feed={'d0': d[0],
'd1': d[1],
'd2': d[2]},
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ['FLAGS_memory_fraction_of_eager_deletion'] = "0.55"
os.environ['RECORDIO_FILENAME'] = './p_gc_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册