diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index f1ce744a93b73aa5f00554f93796663c8a698e80..c4d3d4756435b53301b91a34a5753916dfcfaee4 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -15,6 +15,8 @@ cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_
 cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
 cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
 
+cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
+
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
 
 if(WITH_DISTRIBUTE)
@@ -114,4 +116,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
             fuse_relu_depthwise_conv_pass
             memory_optimize_pass lock_free_optimize_pass
             alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
-            fuse_adam_op_pass fuse_sgd_op_pass)
+            fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass)
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 36720b8ad97c6cacfccc106066bcdbe20c39ff47..7c73bf0b6d37e1e659c2be20dc30d0b354bbb6bd 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -53,6 +53,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       viz_pass->Set("graph_viz_path", new std::string(graph_path));
     }
 
+    // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
+    AppendPass("record_skip_memory_opt_vars_pass");
+
     if (strategy_.enable_sequential_execution_) {
       VLOG(10) << "Add sequential_execution_pass";
       AppendPass("sequential_execution_pass");
@@ -320,3 +323,4 @@ USE_PASS(graph_to_program_pass);
 USE_PASS(fuse_adam_op_pass);
 USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
+USE_PASS(record_skip_memory_opt_vars_pass);
diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc
index 84c9e4a379a5e07dc3a8e85409c804eebc390c73..5e44bb5a8369d42f39ae729a17df7ebbc16edfec 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.cc
+++ b/paddle/fluid/framework/details/inplace_op_pass.cc
@@ -303,7 +303,16 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
     auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
     auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
 
-    VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
+    VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name;
+    if (view_.InSkipSet(in_var_name)) {
+      VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name);
+      continue;
+    }
+
+    if (view_.InSkipSet(out_var_name)) {
+      VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name);
+      continue;
+    }
 
     if (var_nodes_[in_var_name].back() != in_node) {
       VLOG(4) << "SKIP since " << in_var_name
@@ -318,11 +327,15 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
               << out_var_name << " are the same";
     } else if (!NodeCanReused(in_node)) {
       can_replace = false;
-      VLOG(4) << "SKIP: Input varialbe " << in_var_name << "cannot be reused";
+      VLOG(4) << "SKIP: Input variable " << in_var_name << " cannot be reused";
     } else if (!NodeCanReused(out_node)) {
       can_replace = false;
       VLOG(4) << "SKIP: Output variable " << out_var_name
               << " cannot be reused";
+    } else if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
+      can_replace = false;
+      VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType()
+              << " does not match Output type : " << out_node->Var()->GetType();
     } else if (details::NodeSize(*in_node->Var()) !=
                details::NodeSize(*out_node->Var())) {
       can_replace = false;
@@ -331,8 +344,8 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
 
     if (!can_replace) continue;
 
-    // 2. there is no external pending op on the input node
-    // if (view_.PendingOpsOnVar(in_node).size() > 1) {
+    // 2. If the variable is the input of multiple ops, we need to make sure
+    // that the current op has a dependency on the other ops that use it.
     if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
       VLOG(4) << string::Sprintf(
           "Skiped pair %s => %s. %s input has external dependency."
@@ -341,17 +354,6 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
       continue;
     }
 
-    // 3. if output has been memory optimize by python(fluid.memory_optmize()).
-    // this candidate can not be inplaced. Will be deprecated in the future.
-    if (view_.InSkipSet(out_node->Name())) {
-      VLOG(4) << string::Sprintf(
-          "Skiped %s => %s reused previous memory block in python memory "
-          "optmize,"
-          "it inplace may generate a circle",
-          out_var_name, in_var_name, op->Name());
-      continue;
-    }
-
     // Debug Interface. Which would be skipped by the pass.
     if (out_node->Name() == FLAGS_memory_optimize_debug) {
       VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
@@ -519,16 +521,22 @@ void GraphView::Build(ir::Graph* g) {
   // resolve data harzards depends on the var nodes in right order.
   TopoSort(g);
 
+  // fill the skip_set_
+  PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars));
+  auto& mem_opt_whitelist = g->Get<MemOptSkipVars>(kMemOptSkipVars);
+  for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
+
   // 2. track the nodes which used by parameter server.
   // these node can not be inplaced, otherwise trainer
   // pserver can not find each other name.
   auto update_skip_set = [&](ir::Node* node) {
     for (auto& in : node->inputs) {
-      if (in->IsVar() && in->Var() != nullptr) dup_nodes_.emplace(in->Name());
+      if (in->IsVar() && in->Var() != nullptr) {
+        skip_set_.emplace(in->Name());
+      }
     }
     for (auto& out : node->outputs) {
-      if (out->IsVar() && out->Var() != nullptr)
-        dup_nodes_.emplace(out->Name());
+      if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name());
     }
   };
   for (auto& node : g->Nodes()) {
@@ -545,7 +553,7 @@ void GraphView::Build(ir::Graph* g) {
 const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
 
 bool GraphView::InSkipSet(const std::string& var) const {
-  return dup_nodes_.count(var);
+  return skip_set_.count(var);
 }
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h
index fbec973ddaa7673601780810cfbbf8c1128af513..2cd6cbd1b0317c3ea301428f2537023b026e581e 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.h
+++ b/paddle/fluid/framework/details/inplace_op_pass.h
@@ -57,7 +57,7 @@ class GraphView {
 
  private:
   std::vector<ir::Node*> ops_;
-  std::unordered_set<std::string> dup_nodes_;  // mem opt affect nodes
+  std::unordered_set<std::string> skip_set_;  // mem opt affect nodes
   std::map> adj_list_;
   std::unordered_map op_level_;
 };
diff --git a/paddle/fluid/framework/details/memory_optimize_helper.h b/paddle/fluid/framework/details/memory_optimize_helper.h
index 65c7017d2d462976cf8cd4d7b5f660e279e12b6a..0a65ec051df1f676c2818b916d8e32b46b0d2e29 100644
--- a/paddle/fluid/framework/details/memory_optimize_helper.h
+++ b/paddle/fluid/framework/details/memory_optimize_helper.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 #include
 #include
 #include "paddle/fluid/framework/data_type.h"
@@ -30,6 +31,11 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+/// This attribute is used to prevent some core variables from being
+/// removed/reused by the memory-optimize-related passes.
+constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
+typedef std::unordered_set<std::string> MemOptSkipVars;
+
 std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
 
 // NOTE(dzh): A ordered set for node reuse in memory optimize.
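Note: the hunks above define the kMemOptSkipVars graph attribute and make the inplace pass consult it before pairing an input with an output. A minimal Python sketch of seeding that skip set on a raw graph through the bindings added later in this patch (core.kMemOptSkipVars() and the string-set Graph.set overload); the fc layer, the core.Graph(prog.desc) construction, and the choice of protected variable are illustrative assumptions only, and in the normal ParallelExecutor pipeline the attribute is created by record_skip_memory_opt_vars_pass instead, which requires that it does not exist yet:

    import paddle.fluid as fluid
    from paddle.fluid import core

    prog = fluid.Program()
    with fluid.program_guard(prog):
        x = fluid.layers.data(name='x', shape=[8], dtype='float32')
        y = fluid.layers.fc(input=x, size=4)

    # Build an ir::Graph from the program and seed the skip set so that
    # memory-optimize-related passes (inplace_pass, memory_optimize_pass)
    # never reuse or rename this variable.
    graph = core.Graph(prog.desc)
    graph.set(core.kMemOptSkipVars(), {y.name})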
diff --git a/paddle/fluid/framework/details/memory_optimize_pass.cc b/paddle/fluid/framework/details/memory_optimize_pass.cc
index ddaef206028b16dd10c2beb57ce6bf30103a8d10..ef36f1038e27770498d66663a0051dbf8f559f93 100644
--- a/paddle/fluid/framework/details/memory_optimize_pass.cc
+++ b/paddle/fluid/framework/details/memory_optimize_pass.cc
@@ -45,8 +45,7 @@ namespace framework {
 namespace details {
 
 void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
-  auto nodes = graph->Nodes();
-  CollectSkipVarsSet(nodes);
+  CollectSkipVarsSet(graph);
 
   cfg_.reset(new details::ControlFlowGraph(*graph));
   cfg_->LiveVariableAnalysis();
@@ -204,14 +203,20 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
   }
 }
 
-void MemoryOptimizePass::CollectSkipVarsSet(
-    const std::unordered_set<ir::Node*>& nodes) const {
+void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
+  // fill skip_set_
+  PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
+  auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
+  for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
+
   auto update_skip_set = [&](OpDesc* op_desc) {
     auto inputs = op_desc->InputArgumentNames();
     auto outputs = op_desc->OutputArgumentNames();
     skip_set_.insert(inputs.begin(), inputs.end());
     skip_set_.insert(outputs.begin(), outputs.end());
   };
+
+  auto nodes = graph->Nodes();
   for (auto& op : nodes) {
     if (!op->IsOp() || op->Op() == nullptr) continue;
     auto* op_desc = op->Op();
diff --git a/paddle/fluid/framework/details/memory_optimize_pass.h b/paddle/fluid/framework/details/memory_optimize_pass.h
index ce94890b3856fa6bf167b8a08c814f81e422c372..fa5b9b322da8fce53a4205daab96aa649e526335 100644
--- a/paddle/fluid/framework/details/memory_optimize_pass.h
+++ b/paddle/fluid/framework/details/memory_optimize_pass.h
@@ -53,7 +53,8 @@ class MemoryOptimizePass : public ir::Pass {
   // 1. scan op with subblock and collect the output/input vars.
   // while, while_grad, conditional_block
   // 2. scan distributed ops and collect the output/input vars
-  void CollectSkipVarsSet(const std::unordered_set<ir::Node*>&) const;
+  // 3. op_role_vars
+  void CollectSkipVarsSet(ir::Graph* graph) const;
 
  private:
   // Reuse Node Pool, Owned.
diff --git a/paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc b/paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7cb2544ebbfbf42f5e3c014528c56bf17989292e
--- /dev/null
+++ b/paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class RecordSkipMemoryOptVarsPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override {
+    PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
+    graph->Set(kMemOptSkipVars, new MemOptSkipVars);
+    auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
+
+    // NOTE(zcd): Insert OpRoleVars to SkipVarSet to prevent the vars from
+    // being renamed by the memory optimize passes.
+    InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
+  }
+
+  void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
+                                    MemOptSkipVars* skip_vars) const {
+    for (auto& node : graph->Nodes()) {
+      PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
+      if (node->IsOp() && node->Op()) {
+        try {
+          auto op_role_vars =
+              boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                  OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+          PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
+          for (size_t i = 0; i < op_role_vars.size(); i += 2) {
+            auto& g_name = op_role_vars[i + 1];
+            skip_vars->insert(g_name);
+          }
+        } catch (boost::bad_get e) {
+        }
+      }
+    }
+  }
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(record_skip_memory_opt_vars_pass,
+              paddle::framework::details::RecordSkipMemoryOptVarsPass);
diff --git a/paddle/fluid/framework/inplace_op_inference_test.cc b/paddle/fluid/framework/inplace_op_inference_test.cc
index a9b3b889229ee46bf66063c8381bdd02c7229cbd..a2c213945d7d3c0c6f540d994873f633694eeee9 100644
--- a/paddle/fluid/framework/inplace_op_inference_test.cc
+++ b/paddle/fluid/framework/inplace_op_inference_test.cc
@@ -19,6 +19,7 @@
 #include
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/details/inplace_op_pass.h"
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -217,6 +218,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
 
   FakeSuccData(&prog);
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   g = test_SingleOpInplaceInToOut(std::move(g));
 
   auto op_node = GetNodeFromGraph(g.get(), "single_op");
@@ -232,6 +234,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
 
   FakeNoInplaceData(&prog);
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   g = test_SingleOpInplaceInToOut(std::move(g));
 
   auto op_node = GetNodeFromGraph(g.get(), "single_op");
@@ -264,6 +267,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
   prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
 
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
   pass->Apply(g.get());
   auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
@@ -299,6 +303,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
   prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
 
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
   pass->Apply(g.get());
   auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc
index f8ded9f94ecaf3df1e14aead60ae12abcf8c34a9..a2cbeebfb6c3ef2b8e4af7df51965f52d8cee80c 100644
--- a/paddle/fluid/pybind/const_value.cc
+++ b/paddle/fluid/pybind/const_value.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/pybind/const_value.h"
+#include "paddle/fluid/framework/details/memory_optimize_pass.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
@@ -28,6 +29,7 @@ void BindConstValue(pybind11::module* m) {
   m->def("kControlDepVarName",
          [] { return framework::ir::Node::kControlDepVarName; });
   m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
+  m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; });
 
   auto op_proto_and_checker_maker =
       m->def_submodule("op_proto_and_checker_maker");
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index c69ccd507210f976c1cb8ad072928b96693a948d..798e488f5b0c55c9eabdc420baa7bb0380b2fdba 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -84,6 +84,12 @@ void BindGraph(py::module *m) {
              return self.Set(attr_name,
                              new std::unordered_set(attr));
           })
+      .def("set",
+           [](Graph &self, const std::string &attr_name,
+              const std::unordered_set<std::string> &attr) {
+             return self.Set(attr_name,
+                             new std::unordered_set<std::string>(attr));
+           })
       .def("erase", &Graph::Erase)
       .def("nodes", &Graph::Nodes, return_value_policy::reference)
       .def("create_var_node",
diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py
index ac2a40a7c25f7c3ff0cc103647355da55d27fec3..624c9934d5392b57526edea68254ddf45bd79f4c 100644
--- a/python/paddle/fluid/compiler.py
+++ b/python/paddle/fluid/compiler.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import multiprocessing
 import os
 import six
@@ -152,6 +153,39 @@ class CompiledProgram(object):
         else:
             self._places = None
         self._build_strategy.is_distribution = _is_pserver_mode(self._program)
+
+        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
+        # if python memory optimize is turned on, turn off the inplace_pass.
+        # memory_optimize and enable_inplace default to True, but we can disable them on purpose
+        if self._program:
+            if self._program._is_mem_optimized:
+                self._build_strategy.memory_optimize = False
+                self._build_strategy.enable_inplace = False
+            elif not self._build_strategy.memory_optimize or not self._build_strategy.enable_inplace:
+                # remind the user to try our memory optimize strategy
+                logging.warn("""
+     You can try our memory optimize feature to save your memory usage:
+         # create a build_strategy variable to set memory optimize option
+         build_strategy = compiler.BuildStrategy()
+         build_strategy.enable_inplace = True
+         build_strategy.memory_optimize = True
+
+         # pass the build_strategy to with_data_parallel API
+         compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
+             loss_name=loss.name, build_strategy=build_strategy)
+
+     !!! Memory optimize is our experimental feature !!!
+         some variables may be removed/reused internally to save memory usage,
+         in order to fetch the right value of the fetch_list, please set the
+         persistable property to True for each variable in fetch_list
+
+         # Sample
+         conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
+         # if you need to fetch conv1, then:
+         conv1.persistable = True
+
+                 """)
+
         return self
 
     def with_inference_optimize(self, config):
@@ -211,15 +245,6 @@ class CompiledProgram(object):
         else:
             self._exec_strategy.num_threads = len(self._places) * 2
 
-        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
-        # if turn on python memory optimize, turn off the inplace_pass.
-        # memory_optimize and enable_inplace default are True, but we can disable them on purpose
-        if self._program and self._program._is_mem_optimized:
-            self._build_strategy.memory_optimize = False
-
-        if self._program and self._program._is_mem_optimized:
-            self._build_strategy.enable_inplace = False
-
         # TODO(wuyi): trainer endpoings should be passed in through
         # build_strategy, not program.xxx.
         if self._program and self._build_strategy.num_trainers > 1 and \
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index e15197037e1d901855883919b02a1574b7bc9a29..3983ef09aa38de6643746f1853b5019fa4cb34c1 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -14,6 +14,7 @@
 
 from __future__ import print_function
 
+import logging
 import os
 import multiprocessing
 import numpy as np
@@ -449,6 +450,36 @@ class Executor(object):
             return as_numpy(arr)
         return [arr[i] for i in range(len(arr))]
 
+    def _check_fetch_vars_persistable(self, program, fetch_list):
+        for var in fetch_list:
+            if isinstance(var, Variable):
+                persistable = var.persistable
+            else:
+                block_num = program.desc.num_blocks()
+                persistable = None
+                var_name = cpt.to_bytes(var)
+                for i in six.moves.range(block_num):
+                    var_desc = program.desc.block(i).find_var(var_name)
+                    if var_desc:
+                        persistable = var_desc.persistable()
+                        break
+                assert persistable is not None, "Variable {} is not found".format(
+                    var)
+
+            if not persistable:
+                logging.warn("""
+     Detected that memory optimize or inplace is enabled, but some variables in the
+     fetch list are not persistable. You may get a wrong fetched value, or an
+     exception may be thrown because a variable in the fetch list cannot be found.
+ + TO FIX this: + # Sample + conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None) + # if you need to fetch conv1, then: + conv1.persistable = True + + """) + def run(self, program=None, feed=None, @@ -532,6 +563,11 @@ class Executor(object): scope=scope, return_numpy=return_numpy, use_program_cache=use_program_cache) + else: + if fetch_list and program._is_data_parallel and program._program and ( + program._build_strategy.memory_optimize or + program._build_strategy.enable_inplace): + self._check_fetch_vars_persistable(program._program, fetch_list) program._compile(scope, self.place) if program._is_data_parallel: diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 18ed02a72275437fa6106e57c0383e17647d9700..c9a5d033e496a359656d881ea7fc112e64cb1d1d 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -57,12 +57,15 @@ class TestParallelExecutorBase(unittest.TestCase): startup = fluid.Program() startup.random_seed = 1 # Fix random seed main.random_seed = 1 + with fluid.program_guard(main, startup): if seed is not None: startup.random_seed = seed main.random_seed = seed loss = method(use_feed=feed_dict is not None) + loss.persistable = True + if optimizer: optimizer().minimize(loss)