From 1202d3fc740b065590f93a6029ab2126ca9dafff Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Sun, 21 Apr 2019 10:47:15 -0500 Subject: [PATCH] Refine model gpu memory (#16993) * speedup gc and inplace softmax_with_cross_entropy_grad test=develop * refine models gpu mem Merge skip vars and warning messages of mem opt remove relu mem opt test=develop * follow comments test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 4 +- .../fluid/framework/details/build_strategy.cc | 4 ++ .../details/eager_deletion_op_handle.cc | 2 +- .../details/eager_deletion_op_handle.h | 12 +++- .../fast_threaded_ssa_graph_executor.cc | 48 ++++++++++---- .../fast_threaded_ssa_graph_executor.h | 4 ++ .../framework/details/inplace_op_pass.cc | 59 +++++++++++------ .../fluid/framework/details/inplace_op_pass.h | 2 +- .../details/memory_optimize_helper.h | 6 ++ .../framework/details/memory_optimize_pass.cc | 13 ++-- .../framework/details/memory_optimize_pass.h | 3 +- .../fluid/framework/details/op_handle_base.h | 11 ++++ .../record_skip_memory_opt_vars_pass.cc | 64 +++++++++++++++++++ paddle/fluid/framework/garbage_collector.h | 9 +++ .../framework/inplace_op_inference_test.cc | 5 ++ paddle/fluid/operators/affine_channel_op.cc | 43 ++++++++++++- paddle/fluid/operators/affine_channel_op.cu | 7 +- .../softmax_with_cross_entropy_op.cc | 15 ++++- .../softmax_with_cross_entropy_op.cu | 7 +- .../operators/softmax_with_cross_entropy_op.h | 7 +- paddle/fluid/platform/device_context.h | 2 + paddle/fluid/pybind/const_value.cc | 2 + paddle/fluid/pybind/ir.cc | 6 ++ python/paddle/fluid/compiler.py | 43 ++++++++++--- python/paddle/fluid/executor.py | 36 +++++++++++ .../unittests/parallel_executor_test_base.py | 3 + 26 files changed, 358 insertions(+), 59 deletions(-) create mode 100644 paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 51231b981b..ae89f03186 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -15,6 +15,8 @@ cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_ cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper) cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper) +cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper) + cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) if(WITH_DISTRIBUTE) @@ -124,4 +126,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS fuse_relu_depthwise_conv_pass memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass - fuse_adam_op_pass fuse_sgd_op_pass) + fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 196603bbff..e5dc89ee69 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -53,6 +53,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { viz_pass->Set("graph_viz_path", new std::string(graph_path)); } + // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass. + AppendPass("record_skip_memory_opt_vars_pass"); + if (strategy_.enable_sequential_execution_) { VLOG(10) << "Add sequential_execution_pass"; AppendPass("sequential_execution_pass"); @@ -341,3 +344,4 @@ USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); USE_PASS(expected_kernel_cache_pass); +USE_PASS(record_skip_memory_opt_vars_pass); diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index dbc90737f2..52e6d599eb 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -34,7 +34,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle( AtomicReferenceCountMap *ref_cnts) : OpHandleBase(node), scope_(scope), - var_names_(var_names), + var_names_(var_names.begin(), var_names.end()), gc_(gc), ref_cnts_(ref_cnts) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.h b/paddle/fluid/framework/details/eager_deletion_op_handle.h index 64867afad5..6300b9173b 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.h +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.h @@ -15,7 +15,10 @@ #pragma once #include +#include #include +#include +#include #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h" @@ -37,6 +40,13 @@ class EagerDeletionOpHandle : public OpHandleBase { std::string Name() const override; + /** + * Currently, EagerDeletionOpHandle has the highest priority. + * This priority settings speed up gc 15% in Transformer + * V100 8-GPU model. + */ + Priority GetPriority() const override { return kHighest; } + protected: void RunImpl() override; @@ -44,7 +54,7 @@ class EagerDeletionOpHandle : public OpHandleBase { void ClearGarbages(std::deque> *garbages); const Scope *scope_; - std::unordered_set var_names_; + std::vector var_names_; GarbageCollector *gc_; // not own AtomicReferenceCountMap *ref_cnts_; // not own #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 3e805bd5b4..c69f148297 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include +#include #include #include #include @@ -131,32 +132,53 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( return fetches; } +bool FastThreadedSSAGraphExecutor::RunOp( + OpHandleBase *op, const std::shared_ptr> &complete_q, + size_t *complete) { + try { + if (LIKELY(!strategy_.dry_run_)) { + op->Run(strategy_.use_cuda_); + } + ++(*complete); + return true; + } catch (...) { + exception_.Catch(std::current_exception()); + --remaining_; + complete_q->Push(-1UL); + return false; + } +} + void FastThreadedSSAGraphExecutor::RunOpAsync( std::unordered_map> *op_deps, OpHandleBase *op, const std::shared_ptr> &complete_q) { ++remaining_; this->pool_.enqueue([=] { - OpHandleBase *op_to_run = op; + std::queue op_queue; + op_queue.push(op); + size_t complete = 0; - while (op_to_run != nullptr) { - try { - if (LIKELY(!strategy_.dry_run_)) { - op_to_run->Run(strategy_.use_cuda_); - } - ++complete; - } catch (...) { - exception_.Catch(std::current_exception()); - --remaining_; - complete_q->Push(-1UL); + while (!op_queue.empty()) { + OpHandleBase *op_to_run = op_queue.front(); + op_queue.pop(); + + if (!RunOp(op_to_run, complete_q, &complete)) { return; } + auto &outputs = op_to_run->Outputs(); op_to_run = nullptr; for (auto &output : outputs) { for (auto &pending_op : output->PendingOps()) { std::atomic &deps = op_deps->at(pending_op); - if (deps.fetch_sub(1) == 1) { // pending_op ready + if (deps.fetch_sub(1) != 1) continue; + + // NOTE(zjl): op with highest priority should run + // first without switching to another thread. + if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) { + op_queue.push(pending_op); + } else { if (op_to_run == nullptr) { op_to_run = pending_op; } else { @@ -165,6 +187,8 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( } } } + + if (op_to_run != nullptr) op_queue.push(op_to_run); } --remaining_; complete_q->Push(complete); diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h index f6d5160e75..234da5b925 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -60,6 +60,10 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ::ThreadPool pool_; ::ThreadPool prepare_pool_; + bool RunOp(OpHandleBase *op, + const std::shared_ptr> &complete_q, + size_t *complete); + void RunOpAsync(std::unordered_map> *op_deps, OpHandleBase *op, const std::shared_ptr> &complete_q); diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc index 84c9e4a379..03b21ae0ae 100644 --- a/paddle/fluid/framework/details/inplace_op_pass.cc +++ b/paddle/fluid/framework/details/inplace_op_pass.cc @@ -78,6 +78,13 @@ const std::string kInplacedOpWhiteList[] = { // NOLINT "elementwise_add", "elementwise_add_grad", }; + +// FIXME(zjl): Shapes of in-out of some ops are exactly the same, +// but the static size during compiling time would be wrong. +// Use a flag to indicate such ops. Please fix me when found a better way. +static const std::unordered_set kSameShapeOpWhiteSet{ // NOLINT + "reshape2" +}; // clang-format on namespace paddle { @@ -303,7 +310,16 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, auto* in_node = view_.GetNodeByName(in_var_name, op->inputs); auto* out_node = view_.GetNodeByName(out_var_name, op->outputs); - VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name; + VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name; + if (view_.InSkipSet(in_var_name)) { + VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name); + continue; + } + + if (view_.InSkipSet(out_var_name)) { + VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name); + continue; + } if (var_nodes_[in_var_name].back() != in_node) { VLOG(4) << "SKIP since " << in_var_name @@ -318,21 +334,26 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, << out_var_name << " are the same"; } else if (!NodeCanReused(in_node)) { can_replace = false; - VLOG(4) << "SKIP: Input varialbe " << in_var_name << "cannot be reused"; + VLOG(4) << "SKIP: Input variable " << in_var_name << "cannot be reused"; } else if (!NodeCanReused(out_node)) { can_replace = false; VLOG(4) << "SKIP: Output variable " << out_var_name << " cannot be reused"; + } else if (in_node->Var()->GetType() != out_node->Var()->GetType()) { + can_replace = false; + VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType() + << " does not match Output type : " << out_node->Var()->GetType(); } else if (details::NodeSize(*in_node->Var()) != - details::NodeSize(*out_node->Var())) { + details::NodeSize(*out_node->Var()) && + kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) { can_replace = false; VLOG(4) << "SKIP: Input and Output varialbe size not match"; } if (!can_replace) continue; - // 2. there is no external pending op on the input node - // if (view_.PendingOpsOnVar(in_node).size() > 1) { + // 2. If the variable is the input of muliple ops, we need to make sure + // current op has dependecny on other ops use the same variable if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) { VLOG(4) << string::Sprintf( "Skiped pair %s => %s. %s input has external dependency." @@ -341,17 +362,6 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, continue; } - // 3. if output has been memory optimize by python(fluid.memory_optmize()). - // this candidate can not be inplaced. Will be deprecated in the future. - if (view_.InSkipSet(out_node->Name())) { - VLOG(4) << string::Sprintf( - "Skiped %s => %s reused previous memory block in python memory " - "optmize," - "it inplace may generate a circle", - out_var_name, in_var_name, op->Name()); - continue; - } - // Debug Interface. Which would be skipped by the pass. if (out_node->Name() == FLAGS_memory_optimize_debug) { VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" @@ -424,6 +434,9 @@ void GraphView::TopoSort(ir::Graph* graph) { for (auto& node : nodes) { if (node->IsOp() && node->Op() != nullptr && deps_map[node] > 0) { all_ops_checked = false; + LOG(WARNING) + << "Node " << node->Name() << " has not been checked. " + << "Maybe some passes have not handle node dependency rightly."; break; } } @@ -519,16 +532,22 @@ void GraphView::Build(ir::Graph* g) { // resolve data harzards depends on the var nodes in right order. TopoSort(g); + // fill the skip_set_ + PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars)); + auto& mem_opt_whitelist = g->Get(kMemOptSkipVars); + for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var); + // 2. track the nodes which used by parameter server. // these node can not be inplaced, otherwise trainer // pserver can not find each other name. auto update_skip_set = [&](ir::Node* node) { for (auto& in : node->inputs) { - if (in->IsVar() && in->Var() != nullptr) dup_nodes_.emplace(in->Name()); + if (in->IsVar() && in->Var() != nullptr) { + skip_set_.emplace(in->Name()); + } } for (auto& out : node->outputs) { - if (out->IsVar() && out->Var() != nullptr) - dup_nodes_.emplace(out->Name()); + if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name()); } }; for (auto& node : g->Nodes()) { @@ -545,7 +564,7 @@ void GraphView::Build(ir::Graph* g) { const std::vector& GraphView::AllOps() { return ops_; } bool GraphView::InSkipSet(const std::string& var) const { - return dup_nodes_.count(var); + return skip_set_.count(var); } } // namespace details diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h index fbec973dda..2cd6cbd1b0 100644 --- a/paddle/fluid/framework/details/inplace_op_pass.h +++ b/paddle/fluid/framework/details/inplace_op_pass.h @@ -57,7 +57,7 @@ class GraphView { private: std::vector ops_; - std::unordered_set dup_nodes_; // mem opt affect nodes + std::unordered_set skip_set_; // mem opt affect nodes std::map> adj_list_; std::unordered_map op_level_; }; diff --git a/paddle/fluid/framework/details/memory_optimize_helper.h b/paddle/fluid/framework/details/memory_optimize_helper.h index 65c7017d2d..0a65ec051d 100644 --- a/paddle/fluid/framework/details/memory_optimize_helper.h +++ b/paddle/fluid/framework/details/memory_optimize_helper.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include "paddle/fluid/framework/data_type.h" @@ -30,6 +31,11 @@ namespace paddle { namespace framework { namespace details { +/// this attribute is used to avoid some core variables removed/reused +/// in memory optimize related passes +constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@"; +typedef std::unordered_set MemOptSkipVars; + std::vector SortOpLikeDescOrder(const ir::Graph& graph); // NOTE(dzh): A ordered set for node reuse in memory optimize. diff --git a/paddle/fluid/framework/details/memory_optimize_pass.cc b/paddle/fluid/framework/details/memory_optimize_pass.cc index ddaef20602..ef36f1038e 100644 --- a/paddle/fluid/framework/details/memory_optimize_pass.cc +++ b/paddle/fluid/framework/details/memory_optimize_pass.cc @@ -45,8 +45,7 @@ namespace framework { namespace details { void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const { - auto nodes = graph->Nodes(); - CollectSkipVarsSet(nodes); + CollectSkipVarsSet(graph); cfg_.reset(new details::ControlFlowGraph(*graph)); cfg_->LiveVariableAnalysis(); @@ -204,14 +203,20 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const { } } -void MemoryOptimizePass::CollectSkipVarsSet( - const std::unordered_set& nodes) const { +void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const { + // fill skip_set_ + PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars)); + auto& mem_opt_whitelist = graph->Get(kMemOptSkipVars); + for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var); + auto update_skip_set = [&](OpDesc* op_desc) { auto inputs = op_desc->InputArgumentNames(); auto outputs = op_desc->OutputArgumentNames(); skip_set_.insert(inputs.begin(), inputs.end()); skip_set_.insert(outputs.begin(), outputs.end()); }; + + auto nodes = graph->Nodes(); for (auto& op : nodes) { if (!op->IsOp() || op->Op() == nullptr) continue; auto* op_desc = op->Op(); diff --git a/paddle/fluid/framework/details/memory_optimize_pass.h b/paddle/fluid/framework/details/memory_optimize_pass.h index ce94890b38..fa5b9b322d 100644 --- a/paddle/fluid/framework/details/memory_optimize_pass.h +++ b/paddle/fluid/framework/details/memory_optimize_pass.h @@ -53,7 +53,8 @@ class MemoryOptimizePass : public ir::Pass { // 1. scan op with subblock and collect the output/input vars. // while, while_grad, conditional_block // 2. scan distributed ops and collect the output/input vars - void CollectSkipVarsSet(const std::unordered_set&) const; + // 3. op_role_vars + void CollectSkipVarsSet(ir::Graph* graph) const; private: // Reuse Node Pool, Owned. diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index e0aa352e95..647b238634 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -15,6 +15,8 @@ #pragma once #include #include +#include +#include #include #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/ir/node.h" @@ -31,6 +33,13 @@ constexpr char kLocalExecScopeName[] = "@LOCAL_SCOPE@"; // It's responsible for populating necessary fields of ir::Node. class OpHandleBase { public: + /** + * NOTE(zjl): Some op should have higher priority than others. + * The higher priority op would run first without switching + * threads in Executor. + */ + enum Priority { kHighest = 0, kNormal = 1 }; + // Owned by `node`. No need to be deleted explicitly. explicit OpHandleBase(ir::Node *node) : node_(node) { node_->WrappedBy(this); @@ -40,6 +49,8 @@ class OpHandleBase { std::string DebugString() const; + virtual Priority GetPriority() const { return kNormal; } + virtual std::string Name() const = 0; void Run(bool use_cuda); diff --git a/paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc b/paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc new file mode 100644 index 0000000000..7cb2544ebb --- /dev/null +++ b/paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/details/memory_optimize_helper.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace details { + +class RecordSkipMemoryOptVarsPass : public ir::Pass { + protected: + void ApplyImpl(ir::Graph* graph) const override { + PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars)); + graph->Set(kMemOptSkipVars, new MemOptSkipVars); + auto& skip_vars = graph->Get(kMemOptSkipVars); + + // NOTE(zcd): Insert OpRoleVars to SkipVarSet to prevent the vars are rename + // in memory optimize pass. + InsertOpRoleVarsToSkipVarSet(graph, &skip_vars); + } + + void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph, + MemOptSkipVars* skip_vars) const { + for (auto& node : graph->Nodes()) { + PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr."); + if (node->IsOp() && node->Op()) { + try { + auto op_role_vars = + boost::get>(node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0); + for (size_t i = 0; i < op_role_vars.size(); i += 2) { + auto& g_name = op_role_vars[i + 1]; + skip_vars->insert(g_name); + } + } catch (boost::bad_get e) { + } + } + } + } +}; + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(record_skip_memory_opt_vars_pass, + paddle::framework::details::RecordSkipMemoryOptVarsPass); diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h index f0b504627a..6ce797bd96 100644 --- a/paddle/fluid/framework/garbage_collector.h +++ b/paddle/fluid/framework/garbage_collector.h @@ -107,6 +107,15 @@ void GarbageCollector::Add(Container &&objs) { template void GarbageCollector::Add(Container &&objs, Callback &&callback) { + // Special case when FLAGS_eager_delete_tensor_gb=0.0 + // It speeds up GC about 2~3%. + if (max_memory_size_ <= 1) { + callback(); + auto *container = new Container(std::move(objs)); + ClearCallback([container] { delete container; }); + return; + } + GarbageQueue *garbage_queue = nullptr; { std::lock_guard guard(mutex_); diff --git a/paddle/fluid/framework/inplace_op_inference_test.cc b/paddle/fluid/framework/inplace_op_inference_test.cc index a9b3b88922..a2c213945d 100644 --- a/paddle/fluid/framework/inplace_op_inference_test.cc +++ b/paddle/fluid/framework/inplace_op_inference_test.cc @@ -19,6 +19,7 @@ #include #include "gtest/gtest.h" #include "paddle/fluid/framework/details/inplace_op_pass.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_registry.h" @@ -217,6 +218,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) { FakeSuccData(&prog); std::unique_ptr g(new ir::Graph(prog)); + g->Set(details::kMemOptSkipVars, new std::unordered_set()); g = test_SingleOpInplaceInToOut(std::move(g)); auto op_node = GetNodeFromGraph(g.get(), "single_op"); @@ -232,6 +234,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) { FakeNoInplaceData(&prog); std::unique_ptr g(new ir::Graph(prog)); + g->Set(details::kMemOptSkipVars, new std::unordered_set()); g = test_SingleOpInplaceInToOut(std::move(g)); auto op_node = GetNodeFromGraph(g.get(), "single_op"); @@ -264,6 +267,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) { prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024}); std::unique_ptr g(new ir::Graph(prog)); + g->Set(details::kMemOptSkipVars, new std::unordered_set()); std::unique_ptr pass(new details::InplacePass()); pass->Apply(g.get()); auto op_node = GetNodeFromGraph(g.get(), "multi_out_op"); @@ -299,6 +303,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) { prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024}); std::unique_ptr g(new ir::Graph(prog)); + g->Set(details::kMemOptSkipVars, new std::unordered_set()); std::unique_ptr pass(new details::InplacePass()); pass->Apply(g.get()); auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad"); diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 27370a3c29..da06354143 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -115,6 +117,14 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { ctx->GetInputDim("Scale")); } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.GetPlace()); + } }; class AffineChannelGradMaker : public framework::SingleGradOpDescMaker { @@ -217,7 +227,6 @@ class AffineChannelGradKernel : public framework::OpKernel { : dims[dims.size() - 1]; int HxW = x->numel() / N / C; - auto* x_d = x->data(); auto* dy_d = dy->data(); auto* scale_d = scale->data(); ConstEigenVectorArrayMap scale_e(scale_d, C); @@ -242,6 +251,7 @@ class AffineChannelGradKernel : public framework::OpKernel { } // compute dscale and dbias if (dscale && dbias) { + auto* x_d = x->data(); dy_d = dy->data(); for (int i = 0; i < N; i++) { ConstEigenArrayMap x_e(x_d, HxW, C); @@ -270,6 +280,7 @@ class AffineChannelGradKernel : public framework::OpKernel { } // compute dscale and dbias if (dscale && dbias) { + auto* x_d = x->data(); ConstEigenArrayMap x_e(x_d, C, num); dscale_e = (x_e * dy_e).rowwise().sum(); dbias_e = dy_e.rowwise().sum(); @@ -278,6 +289,33 @@ class AffineChannelGradKernel : public framework::OpKernel { } }; +class AffineChannelNoNeedBufferVarsInference + : public framework::NoNeedBufferVarsInference { + public: + using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference; + + private: + inline bool HasInput(const std::string& name) const { + auto& inputs = Inputs(); + auto iter = inputs.find(name); + if (iter == inputs.end() || iter->second.empty()) { + return false; + } else { + return iter->second[0] != framework::kEmptyVarName; + } + } + + public: + std::unordered_set operator()() const { + if (!HasInput(framework::GradVarName("Scale")) && + !HasInput(framework::GradVarName("Bias"))) { + return {"X"}; + } else { + return {}; + } + } +}; + } // namespace operators } // namespace paddle @@ -286,7 +324,8 @@ using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp, ops::AffineChannelOpMaker, ops::AffineChannelGradMaker); -REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad); +REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad, + ops::AffineChannelNoNeedBufferVarsInference); REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel, ops::AffineChannelKernel); diff --git a/paddle/fluid/operators/affine_channel_op.cu b/paddle/fluid/operators/affine_channel_op.cu index c054fdb1ba..e1435c29d8 100644 --- a/paddle/fluid/operators/affine_channel_op.cu +++ b/paddle/fluid/operators/affine_channel_op.cu @@ -128,14 +128,13 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel { framework::StringToDataLayout(ctx.Attr("data_layout")); auto& dev_ctx = ctx.template device_context(); - auto dims = x->dims(); - const int num = x->numel(); + auto dims = dy->dims(); + const int num = dy->numel(); int N = dims[0]; int C = layout == framework::DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; int HxW = num / N / C; - const T* x_d = x->data(); const T* dy_d = dy->data(); const T* s_d = scale->data(); @@ -155,6 +154,7 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel { dy_d, s_d, nullptr, C, HxW, num, dx_d); } if (dscale && dbias) { + const T* x_d = x->data(); AffineChannelScaleBiasGradientCUDAKernel< T, block, framework::DataLayout::kNCHW><<>>( @@ -167,6 +167,7 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel { dy_d, s_d, nullptr, C, HxW, num, dx_d); } if (dscale && dbias) { + const T* x_d = x->data(); AffineChannelScaleBiasGradientCUDAKernel< T, block, framework::DataLayout::kNHWC><<>>( diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 7c024e50dd..d7718bda5c 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -14,6 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include +#include +#include +#include namespace paddle { namespace operators { @@ -225,6 +228,15 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { } }; +class SoftmaxWithCrossEntropyGradInplaceInference + : public framework::InplaceOpInference { + public: + std::unordered_map operator()( + const framework::OpDesc& op_desc) const { + return {{"Softmax", framework::GradVarName("Logits")}}; + } +}; + } // namespace operators } // namespace paddle @@ -233,7 +245,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker); REGISTER_OPERATOR(softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyOpGrad); + ops::SoftmaxWithCrossEntropyOpGrad, + ops::SoftmaxWithCrossEntropyGradInplaceInference); REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyKernel, ops::SoftmaxWithCrossEntropyKernel); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index d3b8538124..ed61fb38b5 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -454,8 +454,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { context.Input(framework::GradVarName("Loss"))->data(); Tensor* logit_grad = context.Output(framework::GradVarName("Logits")); - framework::TensorCopy(*context.Input("Softmax"), context.GetPlace(), - context.device_context(), logit_grad); + const Tensor* softmax = context.Input("Softmax"); + if (logit_grad != softmax) { + framework::TensorCopy(*softmax, context.GetPlace(), + context.device_context(), logit_grad); + } T* logit_grad_data = logit_grad->data(); int rank = logit_grad->dims().size(); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index 8cba960c76..7ef7c4f742 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -68,7 +68,12 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { const Tensor* labels = context.Input("Label"); Tensor* logit_grad = context.Output(framework::GradVarName("Logits")); - logit_grad->ShareDataWith(*context.Input("Softmax")); + + const Tensor* softmax = context.Input("Softmax"); + if (logit_grad != softmax) { + framework::TensorCopy(*softmax, context.GetPlace(), + context.device_context(), logit_grad); + } int rank = logit_grad->dims().size(); const int class_num = logit_grad->dims()[rank - 1]; diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 778f6613bd..a86fef33b4 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -158,6 +158,8 @@ class CudnnHolder { if (required_workspace_len > WorkspaceSize()) { ReallocateWorkspace(required_workspace_len); } + VLOG(2) << "Cudnn workspace size: " + << static_cast(WorkspaceSize()) / (1 << 20) << " MB"; cudnn_func(WorkspacePtr()); } diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 71eeaf3b53..4f9885b583 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/pybind/const_value.h" +#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" @@ -33,6 +34,7 @@ void BindConstValue(pybind11::module* m) { m->def("kControlDepVarName", [] { return framework::ir::Node::kControlDepVarName; }); m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; }); + m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; }); auto op_proto_and_checker_maker = m->def_submodule("op_proto_and_checker_maker"); diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index c69ccd5072..798e488f5b 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -84,6 +84,12 @@ void BindGraph(py::module *m) { return self.Set(attr_name, new std::unordered_set(attr)); }) + .def("set", + [](Graph &self, const std::string &attr_name, + const std::unordered_set &attr) { + return self.Set(attr_name, + new std::unordered_set(attr)); + }) .def("erase", &Graph::Erase) .def("nodes", &Graph::Nodes, return_value_policy::reference) .def("create_var_node", diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index ac2a40a7c2..624c9934d5 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import multiprocessing import os import six @@ -152,6 +153,39 @@ class CompiledProgram(object): else: self._places = None self._build_strategy.is_distribution = _is_pserver_mode(self._program) + + # FIXME(dzhwinter): enable_inplace should be after memory_optimize + # if turn on python memory optimize, turn off the inplace_pass. + # memory_optimize and enable_inplace default are True, but we can disable them on purpose + if self._program: + if self._program._is_mem_optimized: + self._build_strategy.memory_optimize = False + self._build_strategy.enable_inplace = False + elif not self._build_strategy.memory_optimize or not self._build_strategy.enable_inplace: + # remind the user to try our memmory optimize strategy + logging.warn(""" + You can try our memory optimize feature to save your memory usage: + # create a build_strategy variable to set memory optimize option + build_strategy = compiler.BuildStrategy() + build_strategy.enable_inplace = True + build_strategy.memory_optimize = True + + # pass the build_strategy to with_data_parallel API + compiled_prog = compiler.CompiledProgram(main).with_data_parallel( + loss_name=loss.name, build_strategy=build_strategy) + + !!! Memory optimize is our experimental feature !!! + some variables may be removed/reused internal to save memory usage, + in order to fetch the right value of the fetch_list, please set the + persistable property to true for each variable in fetch_list + + # Sample + conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None) + # if you need to fetch conv1, then: + conv1.persistable = True + + """) + return self def with_inference_optimize(self, config): @@ -211,15 +245,6 @@ class CompiledProgram(object): else: self._exec_strategy.num_threads = len(self._places) * 2 - # FIXME(dzhwinter): enable_inplace should be after memory_optimize - # if turn on python memory optimize, turn off the inplace_pass. - # memory_optimize and enable_inplace default are True, but we can disable them on purpose - if self._program and self._program._is_mem_optimized: - self._build_strategy.memory_optimize = False - - if self._program and self._program._is_mem_optimized: - self._build_strategy.enable_inplace = False - # TODO(wuyi): trainer endpoings should be passed in through # build_strategy, not program.xxx. if self._program and self._build_strategy.num_trainers > 1 and \ diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index fa8b49a021..0b9a23e676 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -14,6 +14,7 @@ from __future__ import print_function +import logging import os import multiprocessing import numpy as np @@ -449,6 +450,36 @@ class Executor(object): return as_numpy(arr) return [arr[i] for i in range(len(arr))] + def _check_fetch_vars_persistable(self, program, fetch_list): + for var in fetch_list: + if isinstance(var, Variable): + persistable = var.persistable + else: + block_num = program.desc.num_blocks() + persistable = None + var_name = cpt.to_bytes(var) + for i in six.moves.range(block_num): + var_desc = program.desc.block(i).find_var(var_name) + if var_desc: + persistable = var_desc.persistable() + break + assert persistable is not None, "Variable {} is not found".format( + var) + + if not persistable: + logging.warn(""" + Detect that memory optimize or inplace is enabled, but the some variables in the fetch + list is not persistable, you may get wrong fetched value, or an exeception may be thrown + about cannot find variable of the fetch list. + + TO FIX this: + # Sample + conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None) + # if you need to fetch conv1, then: + conv1.persistable = True + + """) + def run(self, program=None, feed=None, @@ -532,6 +563,11 @@ class Executor(object): scope=scope, return_numpy=return_numpy, use_program_cache=use_program_cache) + else: + if fetch_list and program._is_data_parallel and program._program and ( + program._build_strategy.memory_optimize or + program._build_strategy.enable_inplace): + self._check_fetch_vars_persistable(program._program, fetch_list) program._compile(scope, self.place) if program._is_data_parallel: diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 723aafb171..b1391749c0 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -58,12 +58,15 @@ class TestParallelExecutorBase(unittest.TestCase): startup = fluid.Program() startup.random_seed = 1 # Fix random seed main.random_seed = 1 + with fluid.program_guard(main, startup): if seed is not None: startup.random_seed = seed main.random_seed = seed loss = method(use_feed=feed_dict is not None) + loss.persistable = True + if optimizer: optimizer().minimize(loss) -- GitLab