Commit 597dc65e authored by sneaxiy

enhance gc

test=develop
Parent a9ea99d7
@@ -174,7 +174,7 @@ else()
   cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
 endif()

-target_link_libraries(executor garbage_collector)
+target_link_libraries(executor garbage_collector while_op_helper)

 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
         threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
...
@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
 cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
+cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
+cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
 cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
...
@@ -31,6 +31,8 @@ class ComputationOpHandle : public OpHandleBase {
   ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
                       size_t scope_idx);

+  OperatorBase *GetOp() { return op_.get(); }
+
   std::string Name() const override;

   const Scope *GetScope() const { return scope_; }
...
@@ -25,8 +25,6 @@ namespace paddle {
 namespace framework {
 namespace details {

-static const std::string kEagerDeletionOpName{"eager_deletion"};  // NOLINT
-
 EagerDeletionOpHandle::EagerDeletionOpHandle(
     ir::Node *node, const Scope *scope, const platform::Place &place,
     const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
@@ -61,10 +59,9 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
 #endif
 }

-std::string EagerDeletionOpHandle::Name() const { return kEagerDeletionOpName; }
+std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }

 void EagerDeletionOpHandle::RunImpl() {
-  platform::RecordEvent event(kEagerDeletionOpName, nullptr);
   Scope *exec_scope = nullptr;
   std::deque<std::shared_ptr<memory::Allocation>> garbages;
   for (auto &name : var_names_) {
...
@@ -21,35 +21,42 @@
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
-#include "paddle/fluid/framework/details/eager_deletion_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"

-DEFINE_double(fraction_of_eager_deletion, 1.0, "Fraction of eager deletion");
-DEFINE_bool(eager_delete_tensor_only, false, "");
+DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
+              "Fraction of eager deletion");

 namespace paddle {
 namespace framework {
 namespace details {

-namespace {  // NOLINT
+// op -> variables which can be deleted after op runs
 using OpToVarNameSetMap =
     std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
-}  // NOLINT

+// Check whether the variable is LoDTensor based on static VarDesc info
 static bool IsLoDTensor(VarDesc *var) {
   return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
 }

-static int64_t GetNumel(const GraphVars &vars, const std::string &var_name,
-                        size_t scope_idx) {
-  auto *var_desc = TryGetLatestVarDesc(vars[scope_idx].at(var_name));
+// Get memory size of LoDTensor
+static int64_t GetMemorySize(
+    const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
+    const std::string &var_name) {
+  auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
+  PADDLE_ENFORCE_NOT_NULL(var_desc);
   PADDLE_ENFORCE(IsLoDTensor(var_desc));
   auto dims = var_desc->GetShape();
-  return std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
-                         std::multiplies<int64_t>());
+  return SizeOfType(var_desc->GetDataType()) *
+         std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
+                         std::multiplies<int64_t>());
 }

+// Split all variables in the graph into LoDTensor and Non-LoDTensor (e.g.
+// SelectedRows, LoDTensorArray)
+// Since partial GC is based on static analysis of memory size of each variable
+// So we should skip SelectedRows and LoDTensorArray here
 static void SplitIntoLoDTensorAndNonLoDTensorVars(
     const OpToVarNameSetMap &m, const GraphVars &vars,
     OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
@@ -69,76 +76,106 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars(
   }
 }

-static OpToVarNameSetMap ShrinkGCVars(const OpToVarNameSetMap &m,
-                                      const GraphVars &vars,
-                                      double fraction_of_memory_size,
-                                      bool delete_lod_tensor_only = false) {
-  // Do not perform gc
+struct GCVarInfo {
+  GCVarInfo(const std::string &name, int64_t memory_size,
+            ComputationOpHandle *op, size_t scope_idx)
+      : name_(name),
+        memory_size_(memory_size),
+        op_(op),
+        scope_idx_(scope_idx) {}
+
+  std::string name_;         // variable name
+  int64_t memory_size_;      // memory size
+  ComputationOpHandle *op_;  // op after which the variable could be deleted
+  size_t scope_idx_;         // scope index where the variable locates
+
+  int64_t AbsMemorySize() const { return std::abs(memory_size_); }
+};
+
+// Delete delete_lod_tensor_only is not used currently
+static OpToVarNameSetMap ShrinkGCVars(
+    const OpToVarNameSetMap &m, const GraphVars &vars,
+    const std::vector<platform::Place> &places, double fraction_of_memory_size,
+    bool delete_lod_tensor_only = false) {
+  // Do not perform gc when fraction_of_memory_size = 0
   if (fraction_of_memory_size <= 0.0) return {};
-  // Perform complete gc
+
+  /**
+   * Step 1: Split all variables into LoDTensor and Non-LoDTensor.
+   * We can only calculate memory size of LoDTensors
+   */
+  OpToVarNameSetMap lod_tensors, other_vars;
+  SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
+
+  // Perform complete gc when fraction_of_memory_size >= 1
   if (fraction_of_memory_size >= 1.0) {
-    if (delete_lod_tensor_only) {
-      OpToVarNameSetMap lod_tensors, other_vars;
-      SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
-      return lod_tensors;
-    } else {
-      return m;
-    }
+    return delete_lod_tensor_only ? lod_tensors : m;
   }

-  // Perform partial gc
-  OpToVarNameSetMap lod_tensors, other_vars;
-  SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
-
-  using TupleType = std::tuple<std::string, ComputationOpHandle *, int64_t>;
-
-  std::unordered_map<size_t, std::vector<TupleType>> place_to_vars;
-  std::unordered_map<size_t, int64_t> total_memory_size;
+  /**
+   * Step 2: build GCVarInfos, and calculate total memory sizes of each device
+   */
+
+  // place -> variable info (name, memory size, place, scope_idx)
+  std::map<platform::Place, std::vector<GCVarInfo>> place_to_vars;
+  // place -> total memory sizes
+  std::map<platform::Place, int64_t> place_to_size;
+
   for (auto &op_vars_pair : lod_tensors) {
-    auto scope_idx = op_vars_pair.first->GetScopeIdx();
-    int64_t size = 0;
-    for (auto &var_name : op_vars_pair.second) {
-      auto var_size = GetNumel(vars, var_name, scope_idx);
-      size += std::abs(var_size);
-      place_to_vars[scope_idx].emplace_back(var_name, op_vars_pair.first,
-                                            var_size);
+    auto *op = op_vars_pair.first;
+    auto &var_names = op_vars_pair.second;
+    auto scope_idx = op->GetScopeIdx();
+    auto &place = places[scope_idx];
+
+    for (auto &var_name : var_names) {
+      auto var_size = GetMemorySize(vars[scope_idx], var_name);
+      GCVarInfo var_info(var_name, var_size, op, scope_idx);
+      place_to_size[place] += var_info.AbsMemorySize();
+      place_to_vars[place].emplace_back(std::move(var_info));
     }
-    total_memory_size.emplace(scope_idx, size);
   }

-  for (auto &pair : place_to_vars) {
-    std::sort(pair.second.begin(), pair.second.end(),
-              [](const TupleType &t1, const TupleType &t2) {
-                return std::abs(std::get<2>(t1)) > std::abs(std::get<2>(t2));
+  /**
+   * Step 3: sort GCVarInfos, and only delete the largest variables.
+   */
+  OpToVarNameSetMap partial_vars;
+  for (auto &place_to_var_pair : place_to_vars) {
+    auto &place = place_to_var_pair.first;
+    auto &gc_vars = place_to_var_pair.second;
+    std::sort(gc_vars.begin(), gc_vars.end(),
+              [](const GCVarInfo &var1, const GCVarInfo &var2) {
+                return var1.AbsMemorySize() > var2.AbsMemorySize();
               });
-  }

-  OpToVarNameSetMap ret;
-  for (auto &pair : place_to_vars) {
-    auto desired_delete_size = static_cast<int64_t>(
-        fraction_of_memory_size * total_memory_size.at(pair.first));
-    int64_t cur_size = 0;
-    for (size_t i = 0; i < pair.second.size() && cur_size < desired_delete_size;
+    int64_t accumulated_size = 0;
+    int64_t size_threshold =
+        static_cast<int64_t>(fraction_of_memory_size * place_to_size[place]);
+    for (size_t i = 0; i < gc_vars.size() && accumulated_size < size_threshold;
          ++i) {
-      auto &var_name = std::get<0>(pair.second[i]);
-      auto *op = std::get<1>(pair.second[i]);
-      cur_size += std::get<2>(pair.second[i]);
-      ret[op].insert(var_name);
+      partial_vars[gc_vars[i].op_].insert(gc_vars[i].name_);
+      accumulated_size += gc_vars[i].AbsMemorySize();
     }
   }

+  /**
+   * Step 4: Combine other vars (SelectedRows, LoDTensorArray)
+   */
   if (!delete_lod_tensor_only) {
     for (auto &op_vars_pair : other_vars) {
-      for (auto &var_name : op_vars_pair.second) {
-        ret[op_vars_pair.first].insert(var_name);
-      }
+      partial_vars[op_vars_pair.first].insert(op_vars_pair.second.begin(),
+                                              op_vars_pair.second.end());
     }
   }

-  return ret;
+  return partial_vars;
 }

+class EagerDeletionPass : public ir::Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+};
+
 std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts =
@@ -166,9 +203,8 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     }
   }

-  op_vars_map =
-      ShrinkGCVars(op_vars_map, vars, FLAGS_fraction_of_eager_deletion,
-                   FLAGS_eager_delete_tensor_only);
+  op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
+                             FLAGS_memory_fraction_of_eager_deletion);

   for (auto &pair : op_vars_map) {
     auto *op = pair.first;
@@ -200,12 +236,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     eager_deletion_op->AddOutput(dummy_leaf);
   }

-  VLOG(10) << "FLAGS_fraction_of_eager_deletion = "
-           << FLAGS_fraction_of_eager_deletion;
-  VLOG(10) << "FLAGS_eager_delete_tensor_only = "
-           << FLAGS_eager_delete_tensor_only;
+  VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
+           << FLAGS_memory_fraction_of_eager_deletion;
   VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
-  return graph;
+
+  auto while_op_eager_deletion_pass =
+      ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
+  return while_op_eager_deletion_pass->Apply(std::move(graph));
 }

 }  // namespace details
@@ -218,3 +255,5 @@ REGISTER_PASS(eager_deletion_pass,
     .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
     .RequirePassAttr(paddle::framework::details::kAllPlaces)
     .RequirePassAttr(paddle::framework::details::kGarbageCollector);
+
+USE_PASS(while_op_eager_deletion_pass);
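
ShrinkGCVars above is the heart of partial GC: per place, the candidate LoDTensors are sorted by their statically estimated sizes, and deletion targets are collected largest first until the accumulated size reaches FLAGS_memory_fraction_of_eager_deletion of the total. The following is a minimal standalone sketch of that selection step; VarInfo, SelectVarsToGC and the sample sizes are illustrative names and numbers, not Paddle APIs.

// Standalone sketch (not Paddle code) of the size-fraction selection used by
// ShrinkGCVars: pick the largest variables until the configured fraction of
// the total memory is covered.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct VarInfo {
  std::string name;
  int64_t bytes;  // static size estimated from the variable's shape and dtype
};

// Returns the names of the variables selected for eager deletion.
std::vector<std::string> SelectVarsToGC(std::vector<VarInfo> vars,
                                        double fraction) {
  if (fraction <= 0.0) return {};  // GC disabled
  int64_t total = 0;
  for (const auto &v : vars) total += v.bytes;
  // Largest variables first, so fewer deletion targets cover more memory.
  std::sort(vars.begin(), vars.end(),
            [](const VarInfo &a, const VarInfo &b) { return a.bytes > b.bytes; });
  const auto threshold = static_cast<int64_t>(fraction * total);
  std::vector<std::string> picked;
  int64_t accumulated = 0;
  for (size_t i = 0; i < vars.size() && accumulated < threshold; ++i) {
    picked.push_back(vars[i].name);
    accumulated += vars[i].bytes;
  }
  return picked;
}

int main() {
  // 7 MB in total; with fraction 0.55 the 4 MB "fc_out" alone crosses the
  // threshold, so it is the only variable selected.
  std::vector<VarInfo> vars = {
      {"fc_out", 4 << 20}, {"hidden", 2 << 20}, {"label", 1 << 20}};
  for (const auto &name : SelectVarsToGC(vars, 0.55)) {
    std::cout << name << "\n";  // prints: fc_out
  }
  return 0;
}

With the default fraction of 1.0, the branch that returns all candidates is taken, so the previous behavior of deleting every eligible variable is kept.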
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace paddle {
namespace framework {
namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
target_ops[compute_op->GetScopeIdx()].first.emplace_back(
compute_op->GetOp());
} else if (compute_op->Name() == "while_grad") {
target_ops[compute_op->GetScopeIdx()].second.emplace_back(
compute_op->GetOp());
}
}
for (auto &ops_pair : target_ops) {
auto &while_ops = ops_pair.second.first;
auto &while_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/framework/transfer_scope_cache.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -409,8 +410,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   int64_t max_memory_size = GetEagerDeletionThreshold();
   std::unique_ptr<GarbageCollector> gc;
-  // skip while_op and while_grad_op temporarily
-  if (max_memory_size >= 0 && !keep_kids) {
+  if (max_memory_size >= 0) {
     ctx->ResetReferenceCount();
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place_)) {
@@ -428,6 +428,10 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
 #ifdef PADDLE_WITH_CUDA
     }
 #endif
+    if (gc) {
+      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
+                                                                 ctx->ops_);
+    }
   }

   for (auto& op : ctx->ops_) {
...
 include(operators)

 register_operators(DEPS naive_executor)
+cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)

 file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"

 namespace paddle {
@@ -26,14 +27,6 @@ namespace operators {
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;

-static constexpr char kStepBlock[] = "sub_block";
-static constexpr char kCondition[] = "Condition";
-static constexpr char kStepScopes[] = "StepScopes";
-static constexpr char kX[] = "X";
-static constexpr char kXGRAD[] = "X@GRAD";
-static constexpr char kOutputs[] = "Out";
-static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
-
 namespace {  // NOLINT
 static std::string GetSkipEagerDeletionVarsDebugString(
     const std::vector<std::string> &vars) {
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace operators {
// OpVariant is a wrapper class of OpDesc and OperatorBase, so that the two
// representations can be accessed through the same API.
class OpVariant {
struct InputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Inputs());
}
};
struct OutputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Outputs());
}
};
struct AttributeMapVisitor
: public boost::static_visitor<const framework::AttributeMap *> {
const framework::AttributeMap *operator()(
const framework::OpDesc *op) const {
return &(op->GetAttrMap());
}
const framework::AttributeMap *operator()(
const framework::OperatorBase *op) const {
return &(op->Attrs());
}
};
struct RawPointerVisitor : public boost::static_visitor<const void *> {
template <typename OpType>
const void *operator()(const OpType *op) const {
return op;
}
};
public:
OpVariant(const framework::OperatorBase *op) : op_(op) {} // NOLINT
OpVariant(const framework::OpDesc *op) : op_(op) {} // NOLINT
const framework::VariableNameMap &Inputs() const {
return *boost::apply_visitor(InputsVisitor(), op_);
}
const framework::VariableNameMap &Outputs() const {
return *boost::apply_visitor(OutputsVisitor(), op_);
}
const framework::AttributeMap &Attrs() const {
return *boost::apply_visitor(AttributeMapVisitor(), op_);
}
template <typename AttrType>
const AttrType &Attr(const std::string &name) const {
auto &attrs = Attrs();
auto it = attrs.find(name);
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
return boost::get<AttrType>(it->second);
}
bool operator==(const OpVariant &other) const {
return RawPointer() == other.RawPointer();
}
const void *RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_);
}
int which() const { return static_cast<int>(op_.which()); }
struct Hasher {
size_t operator()(const OpVariant &op) const {
return reinterpret_cast<size_t>(op.RawPointer());
}
};
private:
const boost::variant<const framework::OperatorBase *,
const framework::OpDesc *>
op_;
};
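
OpVariant above wraps either a framework::OpDesc* or a framework::OperatorBase* in a boost::variant and routes every accessor through a static visitor, so callers never care which representation they hold. A minimal standalone analogue of that pattern using C++17 std::variant follows; ToyOpDesc, ToyOperator and ToyOpVariant are illustrative stand-ins, not the Paddle classes, and this aside is not part of the file above or below.

// Standalone illustration of the variant-plus-visitor wrapper idea.
#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <vector>

using VariableNameMap = std::map<std::string, std::vector<std::string>>;

// Two toy "op" representations that expose the same logical information.
struct ToyOpDesc {
  VariableNameMap inputs;
  const VariableNameMap &Inputs() const { return inputs; }
};
struct ToyOperator {
  VariableNameMap inputs;
  const VariableNameMap &Inputs() const { return inputs; }
};

// Wrapper that hides which representation it holds, like OpVariant above.
class ToyOpVariant {
 public:
  explicit ToyOpVariant(const ToyOpDesc *op) : op_(op) {}
  explicit ToyOpVariant(const ToyOperator *op) : op_(op) {}

  const VariableNameMap &Inputs() const {
    // std::visit plays the role of boost::apply_visitor.
    return *std::visit(
        [](const auto *op) -> const VariableNameMap * { return &op->Inputs(); },
        op_);
  }

 private:
  std::variant<const ToyOpDesc *, const ToyOperator *> op_;
};

int main() {
  ToyOpDesc desc{{{"X", {"a", "b"}}}};
  ToyOperator op{{{"X", {"c"}}}};
  for (const ToyOpVariant &v : {ToyOpVariant(&desc), ToyOpVariant(&op)}) {
    std::cout << v.Inputs().at("X").size() << "\n";  // prints 2, then 1
  }
  return 0;
}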
static std::string GetDebugString(const std::vector<std::string> &names) {
if (names.empty()) return "";
std::string ret = names[0];
for (size_t i = 1; i < names.size(); ++i) {
ret += (" " + names[i]);
}
return ret;
}
// Set the skip variables of while_op and while_grad_op.
// These variables should be skipped when eager deletion is enabled, because:
//  1. while_grad_op needs some variables defined in while_op.
//  2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
VLOG(2) << "Prepare to skip " << attr.size()
<< " var(s): " << GetDebugString(attr);
attrs[kSkipEagerDeletionVars] = std::move(attr);
}
// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
const OpVariant &grad_op) {
return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}
// Test whether a variable is skippable in the forward while_op.
// A variable is skippable when it is used in while_grad_op but does not come
// from grad_block.
static bool IsSkippableVar(const std::string &name,
framework::BlockDesc *grad_block) {
return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}
static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
const OpVariant &bwd_op) {
auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
// Find all skippable variables in forward while_op
std::unordered_set<std::string> forward_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (IsSkippableVar(in_arg_name, grad_block)) {
forward_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (IsSkippableVar(out_arg_name, grad_block)) {
forward_skip_vars.insert(out_arg_name);
}
}
}
SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
forward_skip_vars.end()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto &fwd_input = fwd_op.Inputs().at(kX);
auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
PADDLE_ENFORCE_EQ(
fwd_input.size(), in_grads.size(),
"Backward input gradient number does not match forward input number.");
std::unordered_set<std::string> backward_skip_vars;
for (size_t i = 0; i < in_grads.size(); ++i) {
if (in_grads[i] == framework::kEmptyVarName) {
continue;
}
backward_skip_vars.insert(in_grads[i]);
backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
}
SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
backward_skip_vars.end()));
}
// Find all while_ops and while_grad_ops in the graph or program.
// A while_op and its while_grad_op may be located in different blocks,
// so we traverse all blocks of the program to find them.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
if (while_ops->empty()) return;
const auto *program =
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "while") {
while_ops->emplace_back(op);
} else if (op->Type() == "while_grad") {
while_grad_ops->emplace_back(op);
}
}
}
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
"There are extra while_grad ops in the graph or program");
}
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
VLOG(2) << "Found while op num: " << while_ops->size()
<< ", while grad op num: " << while_grad_ops->size();
if (while_grad_ops->empty()) {
return;
}
std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
while_ops->begin(), while_ops->end());
for (auto &bwd_op : *while_grad_ops) {
const OpVariant *matched_fwd_op = nullptr;
for (auto &fwd_op : while_op_set) {
if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched while ops");
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward while op.");
ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
while_op_set.erase(*matched_fwd_op);
}
PADDLE_ENFORCE(while_op_set.empty(),
"There are while ops without matched while_grad ops in the graph.");
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, return directly.
// All while_ops and while_grad_ops in the whole program are processed when
// block_id is 0 (i.e. when Executor::Run() or the ParallelExecutor is
// constructed). Moreover, they must be processed when block_id is zero:
// otherwise, a while_op may run first and erase variables that are still
// needed by its while_grad_op, which at that moment may not even be
// constructed yet.
if (block_id != 0) return;
std::vector<OpVariant> fwd_ops, bwd_ops;
for (auto &op : all_ops) {
if (op->Type() == "while") {
fwd_ops.emplace_back(op.get());
} else if (op->Type() == "while_grad") {
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
fwd_ops.reserve(while_ops.size());
for (auto *op : while_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(while_grad_ops.size());
for (auto *op : while_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
} // namespace operators
} // namespace paddle
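
IsMatchedWhileOpAndWhileGradOp above pairs a while_grad op with its forward while op by requiring that both reference the same "X" input list and that the forward "Out" outputs appear among the grad op's inputs. A minimal standalone sketch of that pairing rule on plain maps follows; the names are illustrative, not Paddle APIs.

// Standalone sketch of the while / while_grad pairing rule.
#include <cassert>
#include <map>
#include <string>
#include <vector>

using VariableNameMap = std::map<std::string, std::vector<std::string>>;

// A grad op matches its forward op when the "X" lists are identical and the
// forward "Out" list appears as the grad op's "Out" input.
bool IsMatchedPair(const VariableNameMap &fwd_inputs,
                   const VariableNameMap &fwd_outputs,
                   const VariableNameMap &grad_inputs) {
  return fwd_inputs.at("X") == grad_inputs.at("X") &&
         fwd_outputs.at("Out") == grad_inputs.at("Out");
}

int main() {
  VariableNameMap fwd_in = {{"X", {"x0", "x1"}}};
  VariableNameMap fwd_out = {{"Out", {"out0"}}};
  VariableNameMap grad_in = {{"X", {"x0", "x1"}}, {"Out", {"out0"}}};
  assert(IsMatchedPair(fwd_in, fwd_out, grad_in));

  // A grad op recorded against a different forward output does not match.
  VariableNameMap other_grad_in = {{"X", {"x0", "x1"}}, {"Out", {"out1"}}};
  assert(!IsMatchedPair(fwd_in, fwd_out, other_grad_in));
  return 0;
}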
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
...
@@ -14,19 +14,30 @@
 #pragma once

-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/pass.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/variant.h"

 namespace paddle {
-namespace framework {
-namespace details {
-
-class EagerDeletionPass : public ir::Pass {
- protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
-};
-
-}  // namespace details
-}  // namespace framework
+namespace operators {
+
+static constexpr char kStepBlock[] = "sub_block";
+static constexpr char kCondition[] = "Condition";
+static constexpr char kStepScopes[] = "StepScopes";
+static constexpr char kX[] = "X";
+static constexpr char kXGRAD[] = "X@GRAD";
+static constexpr char kOutputs[] = "Out";
+static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops);
+
+}  // namespace operators
 }  // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['CPU_NUM'] = '2'
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['FLAGS_fast_eager_deletion_mode'] = '1'
import unittest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
from paddle.fluid.backward import append_backward
import paddle.fluid.compiler as compiler
import numpy
import multiprocessing
class TestEagerDeletionWhileOpBase(unittest.TestCase):
def test_main(self):
places = [core.CPUPlace(), ]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for p in places:
for with_data_parallel in [False, True]:
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.run_main(p, with_data_parallel)
def run_main(self, place, with_data_parallel):
self.place = place
self.with_data_parallel = with_data_parallel
if not core.is_compiled_with_cuda() and isinstance(self.place,
core.CUDAPlace):
return
if isinstance(self.place, core.CUDAPlace):
device_cnt = core.get_cuda_device_count(
) if self.with_data_parallel else 1
else:
device_cnt = int(
os.environ.get('CPU_NUM',
multiprocessing.cpu_count())) if self.with_data_parallel else 1
d0 = layers.data(
"d0", shape=[10], append_batch_size=False, dtype='float32')
d1 = layers.data(
"d1", shape=[10], append_batch_size=False, dtype='float32')
d2 = layers.data(
"d2", shape=[10], append_batch_size=False, dtype='float32')
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i)
i = layers.increment(i)
layers.array_write(d1, i, array=data_array)
i = layers.increment(i)
layers.array_write(d2, i, array=data_array)
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len)
j = layers.fill_constant(shape=[1], dtype='int64', value=1)
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = layers.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
d = layers.reshape(d, shape=[10])
prev = layers.reshape(prev, shape=[10])
result = layers.sums(input=[d, prev])
i = layers.increment(x=i, in_place=True)
layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond)
with while_op2.block():
d2 = layers.array_read(array=data_array, i=j)
prev2 = layers.array_read(array=mem_array, i=j)
d2 = layers.reshape(d2, shape=[10])
prev2 = layers.reshape(prev2, shape=[10])
result2 = layers.sums(input=[d2, prev2])
j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
sum_result.persistable = True
tmp = layers.unsqueeze(sum_result, axes=[0])
tmp = layers.expand(tmp, expand_times=[10, 1])
fc = layers.fc(tmp, size=256)
loss = layers.mean(sum_result)
optim = fluid.optimizer.Adam(learning_rate=1e-3)
optim.minimize(loss)
exe = Executor(self.place)
exe.run(fluid.default_startup_program())
prog = compiler.CompiledProgram(fluid.default_main_program())
if self.with_data_parallel:
prog = prog.with_data_parallel()
for _ in range(5):
d = []
for i in range(3):
tmp = numpy.random.random(size=[10]).astype('float32')
if not self.with_data_parallel:
d.append(tmp)
else:
d.append(numpy.array([tmp] * device_cnt))
outs = exe.run(program=prog,
feed={'d0': d[0],
'd1': d[1],
'd2': d[2]},
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ['FLAGS_memory_fraction_of_eager_deletion'] = "0.55"
os.environ[
'RECORDIO_FILENAME'] = '/tmp/eager_deletion_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer
if __name__ == '__main__':
unittest.main()