提交 732fa00e 编写于 作者: S sneaxiy

disable gc in recurrent_op currently

test=develop
上级 d0f8d94c
......@@ -80,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars)
: prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) {
global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
skip_ref_cnt_vars);
const std::vector<std::string>& keep_vars, bool force_disable_gc)
: prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) {
global_ref_cnts_ =
GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars);
}
}
......@@ -189,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars,
bool force_disable_gc) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif
auto ctx = Prepare(pdesc, block_id);
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}
......@@ -362,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars) {
std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext(
program, block_id, skip_ref_cnt_vars, force_disable_gc));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
......@@ -375,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
bool force_disable_gc) {
PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equals to block number %d",
......@@ -385,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
for (auto& bid : block_ids) {
ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid);
ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
force_disable_gc);
} else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
force_disable_gc);
}
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid);
......@@ -414,7 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc;
if (max_memory_size >= 0) {
// FIXME(zjl): recurrent_op is rather complex, we would
// disable gc forcely in recurrent_op
if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
......@@ -432,7 +439,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA
}
#endif
if (gc && keep_kids) {
// If gc is enabled and block size > 1
if (gc && ctx->prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
ctx->ops_);
}
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -30,7 +32,8 @@ namespace framework {
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
std::vector<std::string>(),
bool force_disable_gc = false);
~ExecutorPrepareContext();
......@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
const framework::ProgramDesc& prog_;
size_t block_id_;
bool force_disable_gc_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, size_t> global_ref_cnts_;
......@@ -66,7 +70,10 @@ class Executor {
* Scope
*/
void Run(const ProgramDesc& prog, Scope* scope, int block_id,
bool create_local_scope = true, bool create_vars = true);
bool create_local_scope = true, bool create_vars = true,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
// This API is very slow.
void Run(const ProgramDesc& program, Scope* scope,
......@@ -79,12 +86,14 @@ class Executor {
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
std::vector<std::string>(),
bool force_disable_gc = false);
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
std::vector<std::vector<std::string>>());
std::vector<std::vector<std::string>>(),
bool force_disable_gc = false);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
......
......@@ -270,7 +270,9 @@ class RecurrentOp : public RecurrentBase {
// Every inputs are linked now, execute!
executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/);
false /*create_local_scope*/, true /*create_vars*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
// get device context from pool
platform::DeviceContextPool &pool =
......@@ -385,7 +387,9 @@ class RecurrentGradOp : public RecurrentBase {
VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope
executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/);
false /*create_local_scope*/, true /*create_vars*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
VLOG(5) << "executor.Run finished ";
......
......@@ -876,9 +876,11 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<const platform::Place &>())
.def("close", &Executor::Close)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) {
int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) {
pybind11::gil_scoped_release release;
self.Run(prog, scope, block_id, create_local_scope, create_vars);
self.Run(prog, scope, block_id, create_local_scope, create_vars,
fetch_vars);
});
m.def("init_gflags", framework::InitGflags);
......
......@@ -590,7 +590,7 @@ class Executor(object):
fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope)
exe.run(program.desc, scope, 0, True, True)
exe.run(program.desc, scope, 0, True, True, fetch_var_name)
outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy:
outs = as_numpy(outs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册