Unverified commit 89966525 authored by Zeng Jinle, committed by GitHub

Polish reference count pass (#21324)

* fix ref_cnt pass, test=develop

* add cpp unittests to reference_count_pass, test=develop

* follow comments, test=develop
Parent b39f9476
......@@ -25,7 +25,7 @@ if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor place)
dynload_cuda variable_visitor place device_memory_aligment)
if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
......@@ -46,7 +46,7 @@ else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor place)
variable_visitor place device_memory_aligment)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_rpc)
......@@ -103,4 +103,14 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass
pass_builder
${NGRAPH_BS_DEPS})
if (WITH_MKLDNN)
target_link_libraries(build_strategy mkldnn_placement_pass)
endif()
if (WITH_NGRAPH)
target_link_libraries(build_strategy ngraph_subgraph_pass)
endif()
......@@ -12,3 +12,5 @@ cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handl
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
cc_test(test_reference_count_pass_last_lived_ops SRCS test_reference_count_pass_last_lived_ops.cc DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op)
......@@ -202,34 +202,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
details::OpHandleBase *op, size_t scope_idx) {
std::queue<details::OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
}
return nullptr;
}
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
enum LastLiveOpSearchStatus { kSuccess, kFailure };
static std::unordered_set<details::ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
......@@ -237,22 +210,7 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<details::OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable, so the last operator is the op
// that generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}
// No pending ops or generated op is nullptr
if (candidates.empty()) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
}
auto candidates = var->PendingOps();
// stage two. Try to cast them to computation op.
// return (*status=kFailure) when failed.
......@@ -262,37 +220,41 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
// some op handles may operate on many DeviceContexts; however, our garbage
// collector can only wait on one DeviceContext for now. So currently, we wait
// on the nearest compute op.
std::unordered_set<details::ComputationOpHandle *> computation_op;
std::unordered_set<details::ComputationOpHandle *> computation_ops;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op && compute_op->GetScopeIdx() == scope_idx) {
computation_ops.emplace(compute_op);
} else {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
computation_op.emplace(compute_op);
}
auto *generated_op =
dynamic_cast<details::ComputationOpHandle *>(var->GeneratedOp());
if (generated_op && generated_op->GetScopeIdx() == scope_idx) {
computation_ops.emplace(generated_op);
}
}
// stage three. Try to shrink computation op if any of them does
// not need the buffer of var_name.
// If all computation ops do not need the buffer of var_name,
// return empty computation op set, and mark the status as kShouldPrecede,
// which means that the last living ops of var_name should be
// found in the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) {
*status = LastLiveOpSearchStatus::kShouldPrecede;
if (computation_ops.empty() ||
ShrinkNoNeedBufferVarOpDependency(var_name, &computation_ops)) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
PADDLE_ENFORCE(!computation_op.empty(),
"Computation ops should not be empty");
PADDLE_ENFORCE_EQ(
computation_ops.empty(), false,
platform::errors::InvalidArgument("Computation ops should not be empty"));
// stage four. Try to shrink the computation ops if they depend on each other.
// Keep the smallest subset whose completion implies all the others have finished.
*status = LastLiveOpSearchStatus::kSuccess;
return shrink_func(computation_op);
return shrink_func(computation_ops);
}
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
......@@ -344,47 +306,45 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);
for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) {
if ((*iter)->Node()->IsCtrlVar()) {
break;
}
PADDLE_ENFORCE_EQ(
var_handles.empty(), false,
platform::errors::InvalidArgument("Variable %s not found", var_name));
auto last_ver_var = var_handles.back();
VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
*iter, i, var_name, shrink_func, &status);
// In rare cases, some vars may have no pending or preceding computation ops.
// Just break.
if (status == LastLiveOpSearchStatus::kFailure) {
VLOG(1) << "Cannot find last live ops of variable " << var_name
<< " in scope " << (*iter)->scope_idx();
break;
}
if (last_ver_var->Node()->IsCtrlVar()) {
continue;
}
if (status == LastLiveOpSearchStatus::kShouldPrecede) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
}
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
last_ver_var, i, var_name, shrink_func, &status);
// In rare cases, some vars may have no pending or preceding computation ops.
// Just skip this variable.
if (status == LastLiveOpSearchStatus::kFailure) {
VLOG(1) << "Cannot find last live ops of variable " << var_name
<< " in scope " << last_ver_var->scope_idx();
continue;
}
PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess);
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);
PADDLE_ENFORCE_EQ(
status, LastLiveOpSearchStatus::kSuccess,
platform::errors::InvalidArgument("status must be success"));
PADDLE_ENFORCE_EQ(result.empty(), false,
platform::errors::NotFound(
"Last living ops of %s cannot be empty", var_name));
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
var_infos[i][var_name].reset(
new MemOptVarInfo(var_name, result.size()));
auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name];
last_live_ops_of_var.set_var(*iter);
*(last_live_ops_of_var.mutable_ops()) = std::move(result);
break;
std::string last_live_ops_log_str;
for (auto &each_ret : result) {
last_live_ops_log_str += (" " + each_ret->GetOp()->Type());
}
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name
<< " : " << last_live_ops_log_str;
// In rare cases, all preceding attempts fail.
// Just skip this corner case.
var_infos[i][var_name].reset(new MemOptVarInfo(var_name, result.size()));
auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name];
last_live_ops_of_var.set_var(last_ver_var);
*(last_live_ops_of_var.mutable_ops()) = std::move(result);
}
}
}
......
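For context, the counts recorded above via MemOptVarInfo(var_name, result.size()) are what the eager-deletion garbage collector decrements at runtime. The following minimal standalone sketch shows the idea; it is not Paddle's actual GC code, and VarRefCount / OnLastLiveOpFinished are hypothetical names used only for illustration. Each variable starts with a counter equal to the number of last living ops the pass found for it, and its buffer may be freed once all of those ops have finished.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical sketch: one counter per variable, initialized to the number of
// last living ops the pass found for it.
struct VarRefCount {
  explicit VarRefCount(size_t last_live_op_num) : cnt(last_live_op_num) {}
  // Returns true when the variable's buffer can be released.
  bool OnLastLiveOpFinished() { return --cnt == 0; }
  size_t cnt;
};

int main() {
  std::unordered_map<std::string, VarRefCount> ref_cnts;
  // Suppose the pass found two last living ops for "x5"
  // (elementwise_mul and elementwise_add, as in the unit test below).
  ref_cnts.emplace("x5", VarRefCount(2));
  assert(!ref_cnts.at("x5").OnLastLiveOpFinished());  // first op done, keep x5
  assert(ref_cnts.at("x5").OnLastLiveOpFinished());   // second op done, x5 can be freed
  std::cout << "x5 released" << std::endl;
  return 0;
}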
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"
USE_OP(scale);
USE_OP(elementwise_mul);
USE_OP(elementwise_add);
USE_OP(elementwise_add_grad);
DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace framework {
static std::vector<platform::Place> CreatePlaces(size_t num, bool use_cuda) {
std::vector<platform::Place> result;
result.reserve(num);
for (size_t i = 0; i < num; ++i) {
if (use_cuda) {
result.emplace_back(platform::CUDAPlace(i));
} else {
result.emplace_back(platform::CPUPlace());
}
}
return result;
}
static void NewVar(BlockDesc *block, const std::string &name,
const std::vector<int64_t> &shape) {
auto *var_desc = block->Var(name);
var_desc->SetShape(shape);
}
static void AppendOp(BlockDesc *block, const std::string &type,
VariableNameMap inputs, VariableNameMap outputs,
AttributeMap attrs) {
auto &op_info = OpInfoMap::Instance().Get(type);
if (op_info.Checker()) {
op_info.Checker()->Check(&attrs);
}
auto *op = block->AppendOp();
op->SetType(type);
for (auto &pair : inputs) {
op->SetInput(pair.first, pair.second);
}
for (auto &pair : outputs) {
op->SetOutput(pair.first, pair.second);
for (auto &var_name : pair.second) {
if (!block->FindVarRecursive(var_name)) {
NewVar(block, var_name, {});
}
}
}
op->SetAttrMap(attrs);
op->InferVarType(block);
op->InferShape(*block);
}
class ReferenceCountPassTestHelper {
public:
ReferenceCountPassTestHelper(const ProgramDesc &program, bool use_cuda)
: graph_(program) {
details::BuildStrategy build_strategy;
build_strategy.enable_inplace_ = false;
build_strategy.memory_optimize_ = false;
FLAGS_eager_delete_tensor_gb = -1;
details::ExecutionStrategy exec_strategy;
exec_strategy.use_cuda_ = use_cuda;
executor_.reset(new ParallelExecutor(CreatePlaces(1, use_cuda), {}, "",
&scope_, {}, exec_strategy,
build_strategy, &graph_));
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_);
ref_cnt_pass->Apply(&graph_);
}
bool IsLastLivedOps(const std::string &name,
std::vector<std::string> ops) const {
std::sort(ops.begin(), ops.end());
return LastLivedOpTypes(name) == ops;
}
std::vector<OperatorBase *> LastLivedOps(const std::string &name) const {
auto &ops = last_live_ops_of_vars_[0].at(name).ops();
std::vector<OperatorBase *> ret;
for (auto *op : ops) {
ret.emplace_back(op->GetOp());
}
return ret;
}
private:
std::vector<std::string> LastLivedOpTypes(const std::string &name) const {
auto iter = last_live_ops_of_vars_[0].find(name);
std::vector<std::string> ret;
if (iter != last_live_ops_of_vars_[0].end()) {
for (auto *op : iter->second.ops()) {
ret.emplace_back(op->GetOp()->Type());
}
}
std::sort(ret.begin(), ret.end());
return ret;
}
private:
ir::Graph graph_;
Scope scope_;
std::unique_ptr<ParallelExecutor> executor_;
ir::MemOptVarInfoMapList mem_opt_var_infos_;
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars_;
};
TEST(test_reference_count_pass, test_no_need_buffer_var_shrink) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{{3, 4, 5}};
/**
* The network is:
*
* x0 = fluid.layers.data(...)
* x1 = scale(x0, scale=1)
* x2 = scale(x1, scale=2)
* x3 = elementwise_mul(x1, x2)
* scale(x3, out=x1, scale=3) # produce a new version of x1
* x4, x5 = elementwise_add_grad(dout=x3, x=x2, y=x1)
* x6 = elementwise_mul(x4, x5)
* x7 = elementwise_add(x5, x5)
*/
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
std::string x3 = "x3";
std::string x4 = "x4";
std::string x5 = "x5";
std::string x6 = "x6";
std::string x7 = "x7";
NewVar(block, x0, shape);
AppendOp(block, "scale", {{"X", {x0}}}, {{"Out", {x1}}}, {{"scale", 1.0f}});
AppendOp(block, "scale", {{"X", {x1}}}, {{"Out", {x2}}}, {{"scale", 2.0f}});
AppendOp(block, "elementwise_mul", {{"X", {x1}}, {"Y", {x2}}},
{{"Out", {x3}}}, {});
AppendOp(block, "scale", {{"X", {x3}}}, {{"Out", {x1}}}, {{"scale", 3.0f}});
AppendOp(block, "elementwise_add_grad",
{{GradVarName("Out"), {x3}}, {"X", {x2}}, {"Y", {x1}}},
{{GradVarName("X"), {x4}}, {GradVarName("Y"), {x5}}}, {});
AppendOp(block, "elementwise_mul", {{"X", {x4}}, {"Y", {x5}}},
{{"Out", {x6}}}, {});
AppendOp(block, "elementwise_add", {{"X", {x5}}, {"Y", {x5}}},
{{"Out", {x7}}}, {});
std::vector<bool> use_cuda_list{false};
#ifdef PADDLE_WITH_CUDA
use_cuda_list.push_back(true);
#endif
for (auto use_cuda : use_cuda_list) {
ReferenceCountPassTestHelper helper(program, use_cuda);
ASSERT_TRUE(helper.IsLastLivedOps(x0, {"scale"}));
ASSERT_EQ(
boost::get<float>(helper.LastLivedOps(x0)[0]->Attrs().at("scale")),
1.0f);
ASSERT_TRUE(helper.IsLastLivedOps(x1, {"scale"}));
ASSERT_EQ(
boost::get<float>(helper.LastLivedOps(x1)[0]->Attrs().at("scale")),
3.0f);
ASSERT_TRUE(helper.IsLastLivedOps(x2, {"elementwise_mul"}));
ASSERT_TRUE(helper.IsLastLivedOps(x3, {"elementwise_add_grad"}));
ASSERT_TRUE(helper.IsLastLivedOps(x4, {"elementwise_mul"}));
ASSERT_TRUE(
helper.IsLastLivedOps(x5, {"elementwise_mul", "elementwise_add"}));
ASSERT_TRUE(helper.IsLastLivedOps(x6, {"elementwise_mul"}));
ASSERT_TRUE(helper.IsLastLivedOps(x7, {"elementwise_add"}));
}
}
} // namespace framework
} // namespace paddle
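The x2 assertion above relies on the "stage three" shrink in the pass: elementwise_add_grad appears to consume x2 only through a no-need-buffer input, so it is dropped from x2's last-living-op set and elementwise_mul remains. The following is a minimal standalone sketch of that filtering rule under this assumption; FakeOp and ShrinkByNoNeedBuffer are illustrative names, not Paddle's actual ShrinkNoNeedBufferVarOpDependency.

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Illustrative stand-ins, not Paddle types.
struct FakeOp {
  std::string type;
  std::unordered_set<std::string> inputs;
  std::unordered_set<std::string> no_need_buffer_inputs;
};

// Drop ops that reference `var` only through inputs whose buffers they never
// read; what remains are the ops the garbage collector must actually wait for.
static std::vector<FakeOp> ShrinkByNoNeedBuffer(const std::string &var,
                                                const std::vector<FakeOp> &ops) {
  std::vector<FakeOp> kept;
  for (const auto &op : ops) {
    bool needs_buffer = op.inputs.count(var) != 0 &&
                        op.no_need_buffer_inputs.count(var) == 0;
    if (needs_buffer) kept.push_back(op);
  }
  return kept;
}

int main() {
  // Mirrors the unit test above: elementwise_add_grad consumes x2 only as a
  // no-need-buffer input (assumption inferred from the x2 assertion), so
  // elementwise_mul is left as x2's last living op.
  std::vector<FakeOp> ops = {
      {"elementwise_mul", {"x1", "x2"}, {}},
      {"elementwise_add_grad", {"x3", "x2", "x1"}, {"x2", "x1"}}};
  auto kept = ShrinkByNoNeedBuffer("x2", ops);
  assert(kept.size() == 1 && kept.front().type == "elementwise_mul");
  return 0;
}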
......@@ -267,6 +267,26 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
return graph;
}
/**
* NOTE(zengjinle): If BuildStrategy.memory_optimize = None in Python,
* set BuildStrategy.memory_optimize according to whether gc is enabled.
* If gc is enabled, BuildStrategy.memory_optimize = False.
* If gc is disabled, BuildStrategy.memory_optimize = True.
* This is because gc+memory_optimize is worse than gc only.
*
* As an option, users can forcibly enable BuildStrategy.memory_optimize
* by setting it to True, or forcibly disable it by setting it to False.
*/
bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0);
if (!build_strategy_.memory_optimize_) {
build_strategy_.memory_optimize_ = !is_gc_enabled;
}
bool need_mem_opt = build_strategy_.enable_inplace_ ||
build_strategy_.memory_optimize_.get() || is_gc_enabled;
if (!need_mem_opt) return graph;
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass");
......@@ -288,21 +308,6 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
"build_strategy.enable_inplace = True";
}
/**
* NOTE(zengjinle): If BuildStrategy.memory_optimize = None in Python,
* set BuildStrategy.memory_optimize according to whether gc is enabled.
* If gc is enabled, BuildStrategy.memory_optimize = False.
* If gc is disabled, BuildStrategy.memory_optimize = True.
* This is because gc+memory_optimize is worse than gc only.
*
* As an option, users can forcibly enable BuildStrategy.memory_optimize
* by setting it to True, or forcibly disable it by setting it to False.
*/
bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0);
if (!build_strategy_.memory_optimize_) {
build_strategy_.memory_optimize_ = !is_gc_enabled;
}
if (build_strategy_.memory_optimize_.get()) {
auto cross_op_memory_reuse_pass = ir::PassRegistry::Instance().Get(
"buffer_shared_cross_op_memory_reuse_pass");
......
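To make the defaulting rule in the NOTE(zengjinle) above concrete, here is a minimal standalone sketch: it uses std::optional in place of the boost::optional member, and ResolveMemoryOptimize is a hypothetical helper, not a Paddle function. An unset memory_optimize defaults to the opposite of whether eager deletion (gc) is enabled; an explicit True/False always wins.

#include <cassert>
#include <optional>

// If the user left memory_optimize unset, default it to !gc_enabled.
static bool ResolveMemoryOptimize(std::optional<bool> user_setting,
                                  bool gc_enabled) {
  if (!user_setting.has_value()) {
    user_setting = !gc_enabled;
  }
  return user_setting.value();
}

int main() {
  assert(ResolveMemoryOptimize(std::nullopt, /*gc_enabled=*/true) == false);
  assert(ResolveMemoryOptimize(std::nullopt, /*gc_enabled=*/false) == true);
  assert(ResolveMemoryOptimize(true, /*gc_enabled=*/true) == true);    // forced on
  assert(ResolveMemoryOptimize(false, /*gc_enabled=*/false) == false); // forced off
  return 0;
}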
......@@ -84,7 +84,7 @@ if (WITH_DGC)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc)
......