Unverified commit 1202d3fc, authored by Zeng Jinle, committed by GitHub

Refine model gpu memory (#16993)

* speedup gc and inplace softmax_with_cross_entropy_grad
test=develop

* refine models gpu mem
Merge skip vars and warning messages of mem opt
remove relu mem opt
test=develop

* follow comments
test=develop
Parent af8a041b
@@ -15,6 +15,8 @@ cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE)
@@ -124,4 +126,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
- fuse_adam_op_pass fuse_sgd_op_pass)
+ fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass)
@@ -53,6 +53,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
AppendPass("record_skip_memory_opt_vars_pass");
if (strategy_.enable_sequential_execution_) {
VLOG(10) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
@@ -341,3 +344,4 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
@@ -34,7 +34,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
- var_names_(var_names),
+ var_names_(var_names.begin(), var_names.end()),
gc_(gc),
ref_cnts_(ref_cnts) {
#ifdef PADDLE_WITH_CUDA
......
@@ -15,7 +15,10 @@
#pragma once
#include <deque>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
@@ -37,6 +40,13 @@ class EagerDeletionOpHandle : public OpHandleBase {
std::string Name() const override;
/**
 * Currently, EagerDeletionOpHandle has the highest priority.
 * This priority settings speed up gc 15% in Transformer
 * V100 8-GPU model.
 */
Priority GetPriority() const override { return kHighest; }
protected:
void RunImpl() override;
@@ -44,7 +54,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);
const Scope *scope_;
- std::unordered_set<std::string> var_names_;
+ std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
......
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
@@ -131,32 +132,53 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
return fetches;
}
- void FastThreadedSSAGraphExecutor::RunOpAsync(
- std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
- OpHandleBase *op,
- const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
- ++remaining_;
- this->pool_.enqueue([=] {
- OpHandleBase *op_to_run = op;
- size_t complete = 0;
- while (op_to_run != nullptr) {
+ bool FastThreadedSSAGraphExecutor::RunOp(
+ OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
+ size_t *complete) {
try {
if (LIKELY(!strategy_.dry_run_)) {
- op_to_run->Run(strategy_.use_cuda_);
+ op->Run(strategy_.use_cuda_);
}
- ++complete;
+ ++(*complete);
return true;
} catch (...) {
exception_.Catch(std::current_exception());
--remaining_;
complete_q->Push(-1UL);
return false;
}
}
void FastThreadedSSAGraphExecutor::RunOpAsync(
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
++remaining_;
this->pool_.enqueue([=] {
std::queue<OpHandleBase *> op_queue;
op_queue.push(op);
size_t complete = 0;
while (!op_queue.empty()) {
OpHandleBase *op_to_run = op_queue.front();
op_queue.pop();
if (!RunOp(op_to_run, complete_q, &complete)) {
return;
}
auto &outputs = op_to_run->Outputs();
op_to_run = nullptr;
for (auto &output : outputs) {
for (auto &pending_op : output->PendingOps()) {
std::atomic<int> &deps = op_deps->at(pending_op);
- if (deps.fetch_sub(1) == 1) { // pending_op ready
+ if (deps.fetch_sub(1) != 1) continue;
// NOTE(zjl): op with highest priority should run
// first without switching to another thread.
if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) {
op_queue.push(pending_op);
} else {
if (op_to_run == nullptr) {
op_to_run = pending_op;
} else {
@@ -165,6 +187,8 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
}
}
}
if (op_to_run != nullptr) op_queue.push(op_to_run);
}
--remaining_;
complete_q->Push(complete);
......
@@ -60,6 +60,10 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool pool_;
::ThreadPool prepare_pool_;
bool RunOp(OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete);
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
......
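Restating the executor change outside the diff: each worker now drains a local std::queue of ready ops, and a ready successor whose GetPriority() is kHighest (in this commit, the eager-deletion op) is pushed onto that queue so it runs on the same thread instead of being re-enqueued into the thread pool. A minimal standalone sketch of that scheduling idea, with Node, RunChain, and enqueue_to_pool as simplified stand-ins rather than Paddle's real classes:

    #include <atomic>
    #include <functional>
    #include <queue>
    #include <unordered_map>
    #include <vector>

    enum class Priority { kHighest, kNormal };

    struct Node {
      std::function<void()> run;             // the op body
      Priority priority = Priority::kNormal;
      std::vector<Node*> pending_ops;        // ops consuming this op's outputs
    };

    // Runs `start`, then keeps draining ready ops on the current thread.
    // Highest-priority ops join the local queue immediately; at most one
    // normal-priority op stays local, the rest go back to the thread pool.
    void RunChain(Node* start,
                  std::unordered_map<Node*, std::atomic<int>>* deps,
                  const std::function<void(Node*)>& enqueue_to_pool) {
      std::queue<Node*> local;
      local.push(start);
      while (!local.empty()) {
        Node* op = local.front();
        local.pop();
        op->run();
        Node* keep_local = nullptr;
        for (Node* pending : op->pending_ops) {
          if (deps->at(pending).fetch_sub(1) != 1) continue;  // not ready yet
          if (pending->priority == Priority::kHighest) {
            local.push(pending);             // e.g. an eager-deletion (GC) op
          } else if (keep_local == nullptr) {
            keep_local = pending;            // keep one op on this thread
          } else {
            enqueue_to_pool(pending);        // hand the rest to the pool
          }
        }
        if (keep_local != nullptr) local.push(keep_local);
      }
    }

The real RunOpAsync additionally tracks the number of completed ops and reports failures through complete_q, which the sketch leaves out.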
@@ -78,6 +78,13 @@ const std::string kInplacedOpWhiteList[] = { // NOLINT
"elementwise_add",
"elementwise_add_grad",
};
// FIXME(zjl): Shapes of in-out of some ops are exactly the same,
// but the static size during compiling time would be wrong.
// Use a flag to indicate such ops. Please fix me when found a better way.
static const std::unordered_set<std::string> kSameShapeOpWhiteSet{ // NOLINT
"reshape2"
};
// clang-format on
namespace paddle {
@@ -303,7 +310,16 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
- VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
+ VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name;
if (view_.InSkipSet(in_var_name)) {
VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name);
continue;
}
if (view_.InSkipSet(out_var_name)) {
VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name);
continue;
}
if (var_nodes_[in_var_name].back() != in_node) {
VLOG(4) << "SKIP since " << in_var_name
@@ -318,21 +334,26 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
<< out_var_name << " are the same";
} else if (!NodeCanReused(in_node)) {
can_replace = false;
- VLOG(4) << "SKIP: Input varialbe " << in_var_name << "cannot be reused";
+ VLOG(4) << "SKIP: Input variable " << in_var_name << "cannot be reused";
} else if (!NodeCanReused(out_node)) {
can_replace = false;
VLOG(4) << "SKIP: Output variable " << out_var_name
<< " cannot be reused";
} else if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
can_replace = false;
VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType()
<< " does not match Output type : " << out_node->Var()->GetType();
} else if (details::NodeSize(*in_node->Var()) !=
- details::NodeSize(*out_node->Var())) {
+ details::NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
can_replace = false;
VLOG(4) << "SKIP: Input and Output varialbe size not match";
}
if (!can_replace) continue;
- // 2. there is no external pending op on the input node
- // if (view_.PendingOpsOnVar(in_node).size() > 1) {
+ // 2. If the variable is the input of muliple ops, we need to make sure
+ // current op has dependecny on other ops use the same variable
if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input has external dependency."
@@ -341,17 +362,6 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
continue;
}
- // 3. if output has been memory optimize by python(fluid.memory_optmize()).
- // this candidate can not be inplaced. Will be deprecated in the future.
- if (view_.InSkipSet(out_node->Name())) {
- VLOG(4) << string::Sprintf(
- "Skiped %s => %s reused previous memory block in python memory "
- "optmize,"
- "it inplace may generate a circle",
- out_var_name, in_var_name, op->Name());
- continue;
- }
// Debug Interface. Which would be skipped by the pass.
if (out_node->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
@@ -424,6 +434,9 @@ void GraphView::TopoSort(ir::Graph* graph) {
for (auto& node : nodes) {
if (node->IsOp() && node->Op() != nullptr && deps_map[node] > 0) {
all_ops_checked = false;
LOG(WARNING)
<< "Node " << node->Name() << " has not been checked. "
<< "Maybe some passes have not handle node dependency rightly.";
break;
}
}
@@ -519,16 +532,22 @@ void GraphView::Build(ir::Graph* g) {
// resolve data harzards depends on the var nodes in right order.
TopoSort(g);
// fill the skip_set_
PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars));
auto& mem_opt_whitelist = g->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
auto update_skip_set = [&](ir::Node* node) {
for (auto& in : node->inputs) {
- if (in->IsVar() && in->Var() != nullptr) dup_nodes_.emplace(in->Name());
+ if (in->IsVar() && in->Var() != nullptr) {
skip_set_.emplace(in->Name());
}
}
for (auto& out : node->outputs) {
- if (out->IsVar() && out->Var() != nullptr)
- dup_nodes_.emplace(out->Name());
+ if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name());
}
};
for (auto& node : g->Nodes()) {
@@ -545,7 +564,7 @@ void GraphView::Build(ir::Graph* g) {
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
bool GraphView::InSkipSet(const std::string& var) const {
- return dup_nodes_.count(var);
+ return skip_set_.count(var);
}
} // namespace details
......
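The new size rule is the subtle one: for ops such as reshape2 the statically inferred sizes of input and output may disagree (for example because of a -1 dimension) even though both tensors always occupy the same memory at runtime, so kSameShapeOpWhiteSet lets them through. A self-contained sketch of that decision; StaticNumel is a hypothetical stand-in, not Paddle's actual details::NodeSize:

    #include <cstdint>
    #include <string>
    #include <unordered_set>
    #include <vector>

    // Ops whose input/output always occupy the same memory at runtime,
    // even if the statically inferred sizes disagree.
    static const std::unordered_set<std::string> kSameShapeOps{"reshape2"};

    // Hypothetical stand-in for details::NodeSize: plain product of the
    // statically inferred dims, where an unknown dim is recorded as -1.
    int64_t StaticNumel(const std::vector<int64_t>& dims) {
      int64_t numel = 1;
      for (int64_t d : dims) numel *= d;
      return numel;
    }

    bool SizeCheckAllowsInplace(const std::string& op_type,
                                const std::vector<int64_t>& in_dims,
                                const std::vector<int64_t>& out_dims) {
      if (StaticNumel(in_dims) == StaticNumel(out_dims)) return true;
      // Static sizes disagree; only trust ops known to keep in/out the same
      // size at runtime (currently just reshape2).
      return kSameShapeOps.count(op_type) > 0;
    }

    // Example: reshape2 from {4, 196} to {-1, 784} fails the static comparison
    // (784 vs -784) yet is still allowed to be inplaced via the whitelist.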
@@ -57,7 +57,7 @@ class GraphView {
private:
std::vector<ir::Node*> ops_;
- std::unordered_set<std::string> dup_nodes_; // mem opt affect nodes
+ std::unordered_set<std::string> skip_set_; // mem opt affect nodes
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
std::unordered_map<ir::Node*, uint32_t> op_level_;
};
......
@@ -21,6 +21,7 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
@@ -30,6 +31,11 @@ namespace paddle {
namespace framework {
namespace details {
/// this attribute is used to avoid some core variables removed/reused
/// in memory optimize related passes
constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
typedef std::unordered_set<std::string> MemOptSkipVars;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// NOTE(dzh): A ordered set for node reuse in memory optimize.
......
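kMemOptSkipVars is a graph-level attribute: record_skip_memory_opt_vars_pass creates it first, and the inplace and memory-optimize passes later merge it into their own skip sets, which is also why the unit tests further below have to call g->Set(...) before applying InplacePass. A toy sketch of that write-then-read protocol, using a simplified attribute map rather than Paddle's ir::Graph:

    #include <cassert>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>

    using MemOptSkipVars = std::unordered_set<std::string>;
    constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";

    // Toy stand-in for ir::Graph's attribute storage: owning, looked up by name.
    class ToyGraph {
     public:
      bool Has(const std::string& name) const { return attrs_.count(name) > 0; }
      void Set(const std::string& name, MemOptSkipVars* attr) {
        attrs_[name].reset(attr);  // the graph takes ownership, like Graph::Set
      }
      MemOptSkipVars& Get(const std::string& name) { return *attrs_.at(name); }

     private:
      std::unordered_map<std::string, std::unique_ptr<MemOptSkipVars>> attrs_;
    };

    int main() {
      ToyGraph graph;
      // record_skip_memory_opt_vars_pass: create the attribute once, up front.
      assert(!graph.Has(kMemOptSkipVars));
      graph.Set(kMemOptSkipVars, new MemOptSkipVars);
      graph.Get(kMemOptSkipVars).insert("fc_0.w_0@GRAD");

      // inplace_pass / memory_optimize_pass: merge the whitelist into skip_set_.
      std::unordered_set<std::string> skip_set;
      for (const auto& var : graph.Get(kMemOptSkipVars)) skip_set.insert(var);
      return skip_set.count("fc_0.w_0@GRAD") == 1 ? 0 : 1;
    }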
@@ -45,8 +45,7 @@ namespace framework {
namespace details {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
- auto nodes = graph->Nodes();
- CollectSkipVarsSet(nodes);
+ CollectSkipVarsSet(graph);
cfg_.reset(new details::ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis();
@@ -204,14 +203,20 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
}
}
- void MemoryOptimizePass::CollectSkipVarsSet(
- const std::unordered_set<ir::Node*>& nodes) const {
+ void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
+ // fill skip_set_
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
auto update_skip_set = [&](OpDesc* op_desc) {
auto inputs = op_desc->InputArgumentNames();
auto outputs = op_desc->OutputArgumentNames();
skip_set_.insert(inputs.begin(), inputs.end());
skip_set_.insert(outputs.begin(), outputs.end());
};
auto nodes = graph->Nodes();
for (auto& op : nodes) {
if (!op->IsOp() || op->Op() == nullptr) continue;
auto* op_desc = op->Op();
......
@@ -53,7 +53,8 @@ class MemoryOptimizePass : public ir::Pass {
// 1. scan op with subblock and collect the output/input vars.
// while, while_grad, conditional_block
// 2. scan distributed ops and collect the output/input vars
- void CollectSkipVarsSet(const std::unordered_set<ir::Node*>&) const;
+ // 3. op_role_vars
+ void CollectSkipVarsSet(ir::Graph* graph) const;
private:
// Reuse Node Pool, Owned.
......
@@ -15,6 +15,8 @@
#pragma once
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/node.h"
@@ -31,6 +33,13 @@ constexpr char kLocalExecScopeName[] = "@LOCAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node.
class OpHandleBase {
public:
/**
 * NOTE(zjl): Some op should have higher priority than others.
 * The higher priority op would run first without switching
 * threads in Executor.
 */
enum Priority { kHighest = 0, kNormal = 1 };
// Owned by `node`. No need to be deleted explicitly.
explicit OpHandleBase(ir::Node *node) : node_(node) {
node_->WrappedBy(this);
@@ -40,6 +49,8 @@ class OpHandleBase {
std::string DebugString() const;
virtual Priority GetPriority() const { return kNormal; }
virtual std::string Name() const = 0;
void Run(bool use_cuda);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
graph->Set(kMemOptSkipVars, new MemOptSkipVars);
auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
// NOTE(zcd): Insert OpRoleVars to SkipVarSet to prevent the vars are rename
// in memory optimize pass.
InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
}
void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
MemOptSkipVars* skip_vars) const {
for (auto& node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
if (node->IsOp() && node->Op()) {
try {
auto op_role_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
auto& g_name = op_role_vars[i + 1];
skip_vars->insert(g_name);
}
} catch (boost::bad_get e) {
}
}
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::details::RecordSkipMemoryOptVarsPass);
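The pass above relies on the OpRoleVar attribute being a flat list of (parameter, gradient) name pairs, which is why it asserts an even length and records only every second entry, the gradient names. A small sketch of that extraction on a plain vector, independent of Paddle's attribute machinery:

    #include <string>
    #include <unordered_set>
    #include <vector>

    // OpRoleVar is laid out as {param0, grad0, param1, grad1, ...}.
    std::unordered_set<std::string> CollectGradNames(
        const std::vector<std::string>& op_role_vars) {
      std::unordered_set<std::string> grads;
      if (op_role_vars.size() % 2 != 0) return grads;  // malformed attribute
      for (size_t i = 0; i < op_role_vars.size(); i += 2) {
        grads.insert(op_role_vars[i + 1]);  // keep the gradient, skip the param
      }
      return grads;
    }

    // Example: CollectGradNames({"fc_0.w_0", "fc_0.w_0@GRAD"})
    //          yields {"fc_0.w_0@GRAD"}.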
@@ -107,6 +107,15 @@ void GarbageCollector::Add(Container &&objs) {
template <typename Container, typename Callback>
void GarbageCollector::Add(Container &&objs, Callback &&callback) {
// Special case when FLAGS_eager_delete_tensor_gb=0.0
// It speeds up GC about 2~3%.
if (max_memory_size_ <= 1) {
callback();
auto *container = new Container(std::move(objs));
ClearCallback([container] { delete container; });
return;
}
GarbageQueue *garbage_queue = nullptr;
{
std::lock_guard<std::mutex> guard(mutex_);
......
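When FLAGS_eager_delete_tensor_gb is 0.0 the collector's threshold is effectively zero, so queueing garbage behind a mutex only adds overhead; the new branch frees the objects right away. A condensed sketch of the two code paths with a simplified collector that is not the real GarbageCollector interface:

    #include <cstddef>
    #include <deque>
    #include <functional>
    #include <memory>
    #include <mutex>

    class TinyCollector {
     public:
      explicit TinyCollector(std::size_t max_memory_size)
          : max_memory_size_(max_memory_size) {}

      template <typename Container, typename Callback>
      void Add(Container&& objs, Callback&& callback) {
        // Fast path: with an (almost) zero threshold there is nothing to
        // batch, so release immediately and skip the queue and its lock.
        if (max_memory_size_ <= 1) {
          callback();  // e.g. mark the stream that last used the tensors
          auto* container = new Container(std::move(objs));
          ClearCallback([container] { delete container; });
          return;
        }
        // Slow path: batch garbage and flush once enough has accumulated.
        std::lock_guard<std::mutex> guard(mutex_);
        queue_.emplace_back(new Container(std::move(objs)));
        // ...when the tracked size exceeds max_memory_size_, run callback()
        // and hand the queued containers to ClearCallback for destruction.
      }

     private:
      // Stand-in for the device-specific hook (immediate on CPU, stream
      // callback on GPU in the real GarbageCollector).
      void ClearCallback(const std::function<void()>& fn) { fn(); }

      std::size_t max_memory_size_;
      std::mutex mutex_;
      std::deque<std::shared_ptr<void>> queue_;
    };

In the real class the callback/ClearCallback split matters because GPU memory may only be released after the consuming stream has finished; the sketch keeps just the control flow.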
@@ -19,6 +19,7 @@
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -217,6 +218,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
FakeSuccData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
@@ -232,6 +234,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
FakeNoInplaceData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
@@ -264,6 +267,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
@@ -299,6 +303,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
......
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -115,6 +117,14 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
ctx->GetInputDim("Scale"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class AffineChannelGradMaker : public framework::SingleGradOpDescMaker {
@@ -217,7 +227,6 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
- auto* x_d = x->data<T>();
auto* dy_d = dy->data<T>();
auto* scale_d = scale->data<T>();
ConstEigenVectorArrayMap<T> scale_e(scale_d, C);
@@ -242,6 +251,7 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
}
// compute dscale and dbias
if (dscale && dbias) {
auto* x_d = x->data<T>();
dy_d = dy->data<T>();
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> x_e(x_d, HxW, C);
@@ -270,6 +280,7 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
}
// compute dscale and dbias
if (dscale && dbias) {
auto* x_d = x->data<T>();
ConstEigenArrayMap<T> x_e(x_d, C, num);
dscale_e = (x_e * dy_e).rowwise().sum();
dbias_e = dy_e.rowwise().sum();
@@ -278,6 +289,33 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
}
};
class AffineChannelNoNeedBufferVarsInference
: public framework::NoNeedBufferVarsInference {
public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
private:
inline bool HasInput(const std::string& name) const {
auto& inputs = Inputs();
auto iter = inputs.find(name);
if (iter == inputs.end() || iter->second.empty()) {
return false;
} else {
return iter->second[0] != framework::kEmptyVarName;
}
}
public:
std::unordered_set<std::string> operator()() const {
if (!HasInput(framework::GradVarName("Scale")) &&
!HasInput(framework::GradVarName("Bias"))) {
return {"X"};
} else {
return {};
}
}
};
} // namespace operators
} // namespace paddle
@@ -286,7 +324,8 @@ using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
ops::AffineChannelOpMaker, ops::AffineChannelGradMaker);
- REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad);
+ REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad,
ops::AffineChannelNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
ops::AffineChannelKernel<CPU, double>);
......
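The affine_channel changes all serve one goal: the forward input X is only needed when dScale/dBias are requested, so reading x->data<T>() is deferred into the if (dscale && dbias) branch, AffineChannelNoNeedBufferVarsInference declares X buffer-free when only dX is wanted, and GetExpectedKernelType now takes its dtype from Out@GRAD (presumably because X may no longer hold a buffer at that point). A schematic, CPU-only illustration of why dX does not depend on X; this is not the real kernel:

    #include <cstddef>
    #include <vector>

    // Per-channel affine: y = x * scale[c] + bias[c].  Scalar-form gradients:
    //   dx     = dy * scale[c]                    -> does not need x
    //   dscale = sum(dy * x),  dbias = sum(dy)    -> needs x
    struct AffineChannelGrads {
      std::vector<float> dx, dscale, dbias;
    };

    AffineChannelGrads Backward(const std::vector<float>& dy,
                                const std::vector<float>& scale,
                                const std::vector<float>* x,  // may be dropped
                                bool need_param_grads) {
      AffineChannelGrads g;
      const std::size_t C = scale.size();
      g.dx.resize(dy.size());
      for (std::size_t i = 0; i < dy.size(); ++i) g.dx[i] = dy[i] * scale[i % C];
      if (need_param_grads && x != nullptr) {
        g.dscale.assign(C, 0.f);
        g.dbias.assign(C, 0.f);
        for (std::size_t i = 0; i < dy.size(); ++i) {
          const std::size_t c = i % C;       // channel index (NHWC-like layout)
          g.dscale[c] += dy[i] * (*x)[i];    // the only place x is read
          g.dbias[c] += dy[i];
        }
      }
      return g;
    }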
@@ -128,14 +128,13 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
- auto dims = x->dims();
- const int num = x->numel();
+ auto dims = dy->dims();
+ const int num = dy->numel();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = num / N / C;
- const T* x_d = x->data<T>();
const T* dy_d = dy->data<T>();
const T* s_d = scale->data<T>();
@@ -155,6 +154,7 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
const T* x_d = x->data<T>();
AffineChannelScaleBiasGradientCUDAKernel<
T, block, framework::DataLayout::kNCHW><<<grid2, block, 0,
dev_ctx.stream()>>>(
@@ -167,6 +167,7 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
const T* x_d = x->data<T>();
AffineChannelScaleBiasGradientCUDAKernel<
T, block, framework::DataLayout::kNHWC><<<grid2, block, 0,
dev_ctx.stream()>>>(
......
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
@@ -225,6 +228,15 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
}
};
class SoftmaxWithCrossEntropyGradInplaceInference
: public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc) const {
return {{"Softmax", framework::GradVarName("Logits")}};
}
};
} // namespace operators
} // namespace paddle
@@ -233,7 +245,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
- ops::SoftmaxWithCrossEntropyOpGrad);
+ ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInference);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>,
ops::SoftmaxWithCrossEntropyKernel<double>);
......
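SoftmaxWithCrossEntropyGradInplaceInference tells the inplace pass that Logits@GRAD may reuse the buffer of the Softmax input, and both the CUDA and CPU kernels below now guard the TensorCopy with if (logit_grad != softmax) so nothing is copied when the two variables already share storage. A toy sketch of that pairing and of the self-copy guard, with a minimal tensor type in place of framework::Tensor:

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Toy tensor: its identity (the buffer it owns) is what the guard checks.
    struct ToyTensor {
      std::vector<float> data;
    };

    // Maps an input name to the output name that may reuse its buffer,
    // mirroring {{"Softmax", "Logits@GRAD"}} in the real inference class.
    std::unordered_map<std::string, std::string> InplacePairs() {
      return {{"Softmax", "Logits@GRAD"}};
    }

    void CopySoftmaxIntoLogitsGrad(const ToyTensor* softmax,
                                   ToyTensor* logit_grad) {
      // When the inplace pass has merged the two variables, the pointers are
      // equal and the copy must be skipped: the data is already in place.
      if (logit_grad != softmax) {
        logit_grad->data = softmax->data;
      }
    }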
@@ -454,8 +454,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
- framework::TensorCopy(*context.Input<Tensor>("Softmax"), context.GetPlace(),
+ const Tensor* softmax = context.Input<Tensor>("Softmax");
if (logit_grad != softmax) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
}
T* logit_grad_data = logit_grad->data<T>();
int rank = logit_grad->dims().size();
......
@@ -68,7 +68,12 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
- logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
const Tensor* softmax = context.Input<Tensor>("Softmax");
if (logit_grad != softmax) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
}
int rank = logit_grad->dims().size();
const int class_num = logit_grad->dims()[rank - 1];
......
@@ -158,6 +158,8 @@ class CudnnHolder {
if (required_workspace_len > WorkspaceSize()) {
ReallocateWorkspace(required_workspace_len);
}
VLOG(2) << "Cudnn workspace size: "
<< static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
cudnn_func(WorkspacePtr());
}
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
@@ -33,6 +34,7 @@ void BindConstValue(pybind11::module* m) {
m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; });
auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
......
@@ -84,6 +84,12 @@ void BindGraph(py::module *m) {
return self.Set(attr_name,
new std::unordered_set<const Node *>(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const std::unordered_set<std::string> &attr) {
return self.Set(attr_name,
new std::unordered_set<std::string>(attr));
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def("create_var_node",
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import multiprocessing
import os
import six
@@ -152,6 +153,39 @@ class CompiledProgram(object):
else:
self._places = None
self._build_strategy.is_distribution = _is_pserver_mode(self._program)
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
# memory_optimize and enable_inplace default are True, but we can disable them on purpose
if self._program:
if self._program._is_mem_optimized:
self._build_strategy.memory_optimize = False
self._build_strategy.enable_inplace = False
elif not self._build_strategy.memory_optimize or not self._build_strategy.enable_inplace:
# remind the user to try our memmory optimize strategy
logging.warn("""
You can try our memory optimize feature to save your memory usage:
# create a build_strategy variable to set memory optimize option
build_strategy = compiler.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = True
# pass the build_strategy to with_data_parallel API
compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
!!! Memory optimize is our experimental feature !!!
some variables may be removed/reused internal to save memory usage,
in order to fetch the right value of the fetch_list, please set the
persistable property to true for each variable in fetch_list
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
""")
return self
def with_inference_optimize(self, config):
@@ -211,15 +245,6 @@ class CompiledProgram(object):
else:
self._exec_strategy.num_threads = len(self._places) * 2
- # FIXME(dzhwinter): enable_inplace should be after memory_optimize
- # if turn on python memory optimize, turn off the inplace_pass.
- # memory_optimize and enable_inplace default are True, but we can disable them on purpose
- if self._program and self._program._is_mem_optimized:
- self._build_strategy.memory_optimize = False
- if self._program and self._program._is_mem_optimized:
- self._build_strategy.enable_inplace = False
# TODO(wuyi): trainer endpoings should be passed in through
# build_strategy, not program.xxx.
if self._program and self._build_strategy.num_trainers > 1 and \
......
@@ -14,6 +14,7 @@
from __future__ import print_function
import logging
import os
import multiprocessing
import numpy as np
@@ -449,6 +450,36 @@ class Executor(object):
return as_numpy(arr)
return [arr[i] for i in range(len(arr))]
def _check_fetch_vars_persistable(self, program, fetch_list):
for var in fetch_list:
if isinstance(var, Variable):
persistable = var.persistable
else:
block_num = program.desc.num_blocks()
persistable = None
var_name = cpt.to_bytes(var)
for i in six.moves.range(block_num):
var_desc = program.desc.block(i).find_var(var_name)
if var_desc:
persistable = var_desc.persistable()
break
assert persistable is not None, "Variable {} is not found".format(
var)
if not persistable:
logging.warn("""
Detect that memory optimize or inplace is enabled, but the some variables in the fetch
list is not persistable, you may get wrong fetched value, or an exeception may be thrown
about cannot find variable of the fetch list.
TO FIX this:
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
""")
def run(self,
program=None,
feed=None,
@@ -532,6 +563,11 @@ class Executor(object):
scope=scope,
return_numpy=return_numpy,
use_program_cache=use_program_cache)
else:
if fetch_list and program._is_data_parallel and program._program and (
program._build_strategy.memory_optimize or
program._build_strategy.enable_inplace):
self._check_fetch_vars_persistable(program._program, fetch_list)
program._compile(scope, self.place)
if program._is_data_parallel:
......
@@ -58,12 +58,15 @@ class TestParallelExecutorBase(unittest.TestCase):
startup = fluid.Program()
startup.random_seed = 1 # Fix random seed
main.random_seed = 1
with fluid.program_guard(main, startup):
if seed is not None:
startup.random_seed = seed
main.random_seed = seed
loss = method(use_feed=feed_dict is not None)
loss.persistable = True
if optimizer:
optimizer().minimize(loss)
......