提交 612e1a31 编写于 作者: S sneaxiy

modification

上级 d0b2453e
......@@ -23,8 +23,6 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
......
......@@ -89,11 +89,6 @@ class OpHandleBase {
ir::Node *Node() { return node_; }
const std::map<platform::Place, platform::DeviceContext *>
&GetDeviceContexts() const {
return dev_ctxes_;
}
protected:
void RunAndRecordEvent(const std::function<void()> &callback);
......
......@@ -69,15 +69,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
std::string Name() const override { return "reference_count"; }
// protected:
protected:
void RunImpl() override {
auto *exec_scope_ = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<LoDTensor *> tensors;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue;
auto *var = exec_scope_->FindVar(name);
auto *var = exec_scope->FindVar(name);
if (var == nullptr || !var->IsType<LoDTensor>()) continue;
if (it->second.fetch_sub(1) <= 1) {
......@@ -91,8 +91,8 @@ class ReferenceCountOpHandle : public OpHandleBase {
}
private:
void ClearTensors(const std::vector<LoDTensor *> &tensors) const {
auto *gc = dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_);
void ClearTensors(const std::vector<LoDTensor *> &tensors) {
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
......
......@@ -128,12 +128,10 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
auto it = compute_ref_cnt_map.find(op.get());
if (it != compute_ref_cnt_map.end()) {
new_all_ops.emplace_back(std::move(op));
new_all_ops.emplace_back(std::unique_ptr<OpHandleBase>(it->second));
} else {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
if (it != compute_ref_cnt_map.end()) {
new_all_ops.emplace_back(it->second);
}
}
......
......@@ -37,9 +37,11 @@ int kProgramId = -1;
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog),
block_id_(block_id),
ref_cnts_(GetNonPersistableReferenceCount<int>(prog, block_id)) {}
: prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCount<int>(prog_, block_id_);
}
}
ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
......@@ -331,8 +333,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
}
std::shared_ptr<std::vector<framework::LoDTensor*>> erase_tensors(
new std::vector<framework::LoDTensor*>());
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc;
......@@ -353,7 +353,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr) {
std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) {
......@@ -395,7 +394,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
}
#endif
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......@@ -403,10 +401,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
}
if (gc != nullptr)
if (gc != nullptr) {
gc->Wait();
else
} else {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
if (local_scope != scope) {
scope->DeleteScope(local_scope);
......
......@@ -28,8 +28,6 @@ namespace paddle {
namespace framework {
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
int64_t GetEagerDeletionThreshold();
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
......
......@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......
......@@ -29,6 +29,10 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace paddle {
namespace framework {
......
......@@ -32,7 +32,7 @@ DEFINE_bool(
"slow down the destruction of variables.(around 1% performance harm)");
DEFINE_double(
eager_delete_tensor_GB, -1.0,
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
......@@ -40,9 +40,9 @@ namespace paddle {
namespace framework {
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_GB < 0
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_GB *
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
......
......@@ -36,8 +36,6 @@ limitations under the License. */
#endif
#include "unsupported/Eigen/CXX11/Tensor"
DECLARE_bool(clear_gpu_memory_when_unused);
namespace paddle {
namespace platform {
......
......@@ -122,7 +122,7 @@ def __bootstrap__():
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
"dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_GB'
"dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_gb'
]
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册