From 2739096eec359d1060e37dad114183cc2e1cb376 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sun, 27 Jan 2019 16:46:49 +0800 Subject: [PATCH] compatibable with python side mem_opt --- paddle/fluid/framework/details/CMakeLists.txt | 6 +- .../fluid/framework/details/build_strategy.cc | 29 ++++ .../framework/details/graph_print_pass.cc | 125 ++++++++++++++ .../framework/details/graph_print_pass.h | 66 ++++++++ .../details/graph_print_pass_test.cc | 79 +++++++++ .../fluid/framework/details/graph_test_base.h | 80 +++++++++ .../framework/details/inplace_op_pass.cc | 158 ++++++++++++++---- .../details/memory_optimize_pass_test.cc | 55 +----- .../details/multi_devices_graph_print_pass.h | 10 +- .../unittests/parallel_executor_test_base.py | 114 ++++++------- .../tests/unittests/test_ir_inplace_pass.py | 69 ++++++++ 11 files changed, 633 insertions(+), 158 deletions(-) create mode 100644 paddle/fluid/framework/details/graph_print_pass.cc create mode 100644 paddle/fluid/framework/details/graph_print_pass.h create mode 100644 paddle/fluid/framework/details/graph_print_pass_test.cc create mode 100644 paddle/fluid/framework/details/graph_test_base.h create mode 100644 python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index de81f6f671..c4e22615ba 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -51,7 +51,8 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc memory_optimize_helper.cc DEPS graph graph_helper pass) -cc_library(inplace_op_pass SRCS inplace_op_pass DEPS memory_optimize_pass op_info) +cc_library(graph_print_pass SRCS graph_print_pass.cc DEPS graph_helper pass) +cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info graph_print_pass) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) @@ -72,6 +73,7 @@ if (WITH_GPU) endif() cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph) cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass) +cc_test(graph_print_pass_test SRCS graph_print_pass_test.cc DEPS graph_print_pass framework_proto graph graph_helper op_registry pass) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) @@ -96,4 +98,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS multi_devices_graph_print_pass multi_devices_graph_check_pass fuse_elewise_add_act_pass multi_batch_merge_pass fuse_relu_depthwise_conv_pass - memory_optimize_pass lock_free_optimize_pass) + memory_optimize_pass lock_free_optimize_pass graph_print_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 0831772a96..38c03a2604 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/details/graph_print_pass.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" @@ -43,8 +44,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) : ir::PassBuilder(), strategy_(strategy) { if (strategy_.enable_inplace_) { + // before inplaced + // if (!strategy_.debug_graphviz_path_.empty()) { + // const std::string path = strategy_.debug_graphviz_path_ + + // "before_inplaced"; + // auto pass = AppendPass("graph_print_pass"); + // pass->Set(kGraphvizPath, new std::string(path)); + // } + AppendPass("inplace_pass"); + // after inplaced + // if (!strategy_.debug_graphviz_path_.empty()) { + // const std::string path = strategy_.debug_graphviz_path_ + + // "after_inplaced"; + // auto pass = AppendPass("graph_print_pass"); + // pass->Set(details::kGraphvizPath, new + // std::string(path)); + // } } + if (strategy_.enable_sequential_execution_) { AppendPass("sequential_execution_pass"); } @@ -189,6 +207,9 @@ std::unique_ptr BuildStrategy::Apply( pass->SetNotOwned("nccl_ctxs", nctx); #endif } else if (pass->Type() == "memory_optimize_pass") { + if (graph->Has(kAllOpDescs)) { + graph->Erase(kAllOpDescs); + } const std::vector *all_op_descs = new std::vector(main_program.Block(0).AllOps()); graph->Set>(kAllOpDescs, @@ -219,6 +240,9 @@ std::unique_ptr BuildStrategy::Apply( if (graph->Has(kAllOpDescs)) { graph->Erase(kAllOpDescs); } + if (!graph->Has(kGraphviz)) { + graph->Set(kGraphviz, new GraphvizNodes); + } graph->Set>( kAllOpDescs, new std::vector(main_program.Block(0).AllOps())); @@ -228,6 +252,10 @@ std::unique_ptr BuildStrategy::Apply( "GPU, skipped."; continue; } + } else if (pass->Type() == "graph_print_path") { + if (!graph->Has(kGraphviz)) { + graph->Set(kGraphviz, new GraphvizNodes); + } } graph = pass->Apply(std::move(graph)); } @@ -253,3 +281,4 @@ USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(inplace_pass); USE_PASS(lock_free_optimize_pass); +USE_PASS(graph_print_pass); diff --git a/paddle/fluid/framework/details/graph_print_pass.cc b/paddle/fluid/framework/details/graph_print_pass.cc new file mode 100644 index 0000000000..b0a87810db --- /dev/null +++ b/paddle/fluid/framework/details/graph_print_pass.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/graph_print_pass.h" +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +class GraphvizVar : public GraphvizNode { + public: + GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {} + friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) { + sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]" + << std::endl; + return sout; + } +}; + +class GraphvizOp : public GraphvizNode { + public: + GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {} + friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) { + sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name() + << "\", shape=rect]" << std::endl; + PADDLE_ENFORCE(op.stream_.rdbuf()->in_avail() != 0, + "No inputs outputs. Please call AddEdge first!"); + sout << op.stream_.str(); + return sout; + } + template + void AddEdge(const Callback& cb) { + std::string op_name = "op_" + std::to_string(id_); + for (auto var : node_->inputs) { + std::string var_name = "var_" + std::to_string(cb(var)); + stream_ << var_name << "->" << op_name << std::endl; + } + for (auto var : node_->outputs) { + std::string var_name = "var_" + std::to_string(cb(var)); + stream_ << op_name << "->" << var_name << std::endl; + } + } + + private: + std::ostringstream stream_; +}; + +template +std::vector FilterByNodeWrapper(const Container& con) { + std::vector ret; + for (auto& node : con) { + auto i = dynamic_cast(node.get()); + if (i != nullptr) ret.emplace_back(i); + } + return ret; +} + +std::unordered_map SSAGraphPrinterImpl::ToGraphvizNode( + const ir::Graph& graph) const { + // Convert to GraphvizNode format + auto& graphviz_nodes = graph.Get(kGraphviz); + graphviz_nodes.clear(); + std::unordered_map vars; + int var_id = 0; + int op_id = 0; + for (auto& node : graph.Nodes()) { + if (node->IsVar()) { + graphviz_nodes.emplace(new GraphvizVar(node, var_id)); + vars.emplace(std::make_pair(node, var_id++)); + } else if (node->IsOp()) { + graphviz_nodes.emplace(new GraphvizOp(node, op_id++)); + } else { + PADDLE_THROW("Unknown op type"); + } + } + return vars; +} + +void SSAGraphPrinterImpl::Print(const ir::Graph& graph, + std::ostream& sout) const { + auto vars = ToGraphvizNode(graph); + auto& nodes = graph.Get(kGraphviz); + + sout << "digraph G {\n"; + for (auto& var : FilterByNodeWrapper(nodes)) { + sout << *var; + } + + for (auto& op : FilterByNodeWrapper(nodes)) { + op->AddEdge([&vars](ir::Node* var) { return vars.at(var); }); + sout << *op; + } + sout << "}\n"; +} + +std::unique_ptr SSAGraphPrintPass::ApplyImpl( + std::unique_ptr graph) const { + printer_.reset(new SSAGraphPrinterImpl()); + std::unique_ptr fout( + new std::ofstream(Get(kGraphvizPath))); + PADDLE_ENFORCE(fout->good() == true, "Failed to open file."); + + printer_->Print(*graph, *fout); + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass) + .RequirePassAttr(paddle::framework::details::kGraphvizPath); diff --git a/paddle/fluid/framework/details/graph_print_pass.h b/paddle/fluid/framework/details/graph_print_pass.h new file mode 100644 index 0000000000..10ff8c321b --- /dev/null +++ b/paddle/fluid/framework/details/graph_print_pass.h @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/details/multi_devices_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +constexpr char kGraphvizPath[] = "debug_graphviz_path"; +constexpr char kGraphviz[] = "graphviz"; + +class GraphvizNode { + public: + GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {} + virtual ~GraphvizNode() = default; + + protected: + ir::Node* node_; + int id_; +}; +class GraphvizNode; +typedef std::unordered_set> GraphvizNodes; + +class SSAGraphPrinter { + public: + virtual ~SSAGraphPrinter() {} + virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0; +}; + +class SSAGraphPrinterImpl : public SSAGraphPrinter { + public: + void Print(const ir::Graph& graph, std::ostream& sout) const override; + + private: + std::unordered_map ToGraphvizNode( + const ir::Graph& graph) const; +}; + +class SSAGraphPrintPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; + + private: + mutable std::unique_ptr printer_; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/graph_print_pass_test.cc b/paddle/fluid/framework/details/graph_print_pass_test.cc new file mode 100644 index 0000000000..1149d1684e --- /dev/null +++ b/paddle/fluid/framework/details/graph_print_pass_test.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/graph_print_pass.h" +#include "paddle/fluid/framework/details/graph_test_base.h" + +REGISTER_OPERATOR(sum, paddle::framework::DummyOp, + paddle::framework::SumOpMaker); +REGISTER_OPERATOR(split, paddle::framework::DummyOp, + paddle::framework::SplitOpMaker); + +/* + a @ b + c + d @ e + */ + +using paddle::framework::ProgramDesc; +using paddle::framework::proto::VarType; + +inline static ProgramDesc FillProgramDesc() { + ProgramDesc prog; + prog.MutableBlock(0)->Var("a")->SetType(VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b")->SetType(VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c")->SetType(VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("d")->SetType(VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("e")->SetType(VarType::LOD_TENSOR); + { + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a", "b"}); + op->SetOutput("Out", {"c"}); + } + { + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("split"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"d", "e"}); + } + { + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"d", "e"}); + op->SetOutput("Out", {"d"}); + } + return prog; +} + +namespace paddle { +namespace framework { +namespace details { + +TEST(SSAGraphPrinter, Normal) { + auto program = FillProgramDesc(); + std::unique_ptr graph(new ir::Graph(program)); + graph->Set(kGraphviz, new GraphvizNodes); + std::unique_ptr printer(new SSAGraphPrinterImpl); + + // redirect debug graph to a file. + constexpr char graph_path[] = "graph_print_pass.txt"; + std::unique_ptr fout(new std::ofstream(graph_path)); + PADDLE_ENFORCE(fout->good()); + printer->Print(*graph, *fout); +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/graph_test_base.h b/paddle/fluid/framework/details/graph_test_base.h new file mode 100644 index 0000000000..126959bcd8 --- /dev/null +++ b/paddle/fluid/framework/details/graph_test_base.h @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { + +class DummyOp : public OperatorBase { + public: + DummyOp(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} +}; + +class SumOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class AssignOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SplitOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", ""); + AddOutput("Out", "").AsDuplicable(); + AddComment(""); + } +}; + +class DummyVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDesc& op_desc, BlockDesc* block) const override { + auto& inputs = op_desc.Input("X"); + auto type = block->Var(inputs.front())->GetType(); + auto out_var_name = op_desc.Output("Out").front(); + block->Var(out_var_name)->SetType(type); + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc index b08935e566..11ecc383b4 100644 --- a/paddle/fluid/framework/details/inplace_op_pass.cc +++ b/paddle/fluid/framework/details/inplace_op_pass.cc @@ -21,6 +21,7 @@ #include #include #include +#include "paddle/fluid/framework/details/graph_print_pass.h" #include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/op_info.h" @@ -76,42 +77,92 @@ namespace paddle { namespace framework { namespace details { -static inline ir::Node* GetNextInplacedOpOutput(ir::Node* var) { +static inline std::string NodeDebugString(ir::Node* var) { + std::ostringstream os; + if (var->IsCtrlVar()) { + os << "kControlDepVarName" + << " "; + } else if (var->IsOp()) { + os << "kOperation" + << " " << var->Name(); + PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name()); + } else if (var->IsVar()) { + os << "kVariable" + << " " << var->Name(); + PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name()); + } else { + PADDLE_THROW("Unknown node type."); + } + return os.str(); +} + +static inline std::string OpDebugString(ir::Node* var) { + ir::Node* op = var; + if (var->IsVar()) op = var->inputs.at(0); + std::stringstream os; + os << op->Name() << " : "; + + os << "Input "; + VLOG(3) << op->Name(); + for (auto* var : op->inputs) { + if (var->IsVar() && !var->IsCtrlVar()) { + PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(), + "unmatched desc and var"); + // os << var << ":" << var->Name() << " "; + os << var->Name() << " "; + } + } + os << "Output "; + VLOG(3) << op->Name(); + for (auto* var : op->outputs) { + VLOG(3) << var; + VLOG(3) << var->Name(); + if (!var->IsVar()) { + VLOG(3) << "error"; + } + // VLOG(3) << var->Var()->Name(); + if (var->IsVar() && !var->IsCtrlVar()) { + PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(), + "unmatched desc and var"); + // os << var << ":" << var->Name() << " "; + os << var->Name() << " "; + } + if (var->Name() == "fc_10.tmp_0") { + VLOG(3) << NodeDebugString(var); + } + } + return os.str(); +} + +static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { // if next op is inplaced, then return the output var // otherwise return nullptr PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); ir::Node* inplaced_var = nullptr; - // only has one output op can be inplaced - if (var->outputs.size() == 1 && var->outputs[0]->IsOp()) { - auto* op = var->outputs[0]; - for (auto* out_var : op->outputs) { - if (!out_var->IsVar() || out_var->IsCtrlVar() || - out_var->Var() == nullptr) - continue; - if (out_var->Name() == var->Name()) { - inplaced_var = out_var; - break; + for (auto* next_op : var->outputs) { + for (auto* output : next_op->outputs) { + if (output->IsVar() && !output->IsCtrlVar() && + output->Name() == var->Name()) { + inplaced_var = output; } } } return inplaced_var; } -static inline ir::Node* GetPrevInplacedOpInput(ir::Node* var) { +static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) { PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); - ir::Node* inplaced_var = nullptr; - if (var->inputs.size() == 1 && var->inputs[0]->IsOp()) { - auto* op = var->inputs[0]; - for (auto* in_var : op->inputs) { - if (!in_var->IsVar() || in_var->IsCtrlVar() || in_var->Var() == nullptr) - continue; - if (in_var->Name() == var->Name()) { - inplaced_var = in_var; - break; - } - } - } - return inplaced_var; + auto* prev_op = var->inputs.at(0); + auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(), + [&](ir::Node* node) { + if (node->IsVar() && !node->IsCtrlVar() && + node->Name() == var->Name()) { + return true; + } else { + return false; + } + }); + return input_it == prev_op->inputs.end() ? nullptr : *input_it; } template @@ -166,12 +217,22 @@ std::unique_ptr InplacePass::ApplyImpl( view_.Build(graph.get()); InitSSAGraphNodes(); + std::unique_ptr printer(new SSAGraphPrinterImpl); + for (auto* op : view_.AllOps()) { if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) continue; TryInplaceOpInputOutput(op, graph.get()); } graph->ResolveHazard(var_nodes_); + + constexpr char graph_path[] = "ir_graph_inplaced.txt"; + std::unique_ptr fout(new std::ofstream(graph_path)); + PADDLE_ENFORCE(fout->good()); + printer->Print(*graph, *fout); + // for(auto* op : view_.AllOps()) { + // VLOG(3) << OpDebugString(op); + // } return graph; } @@ -179,7 +240,7 @@ void InplacePass::InplaceModifyDesc(const std::string& var, const std::string& cache_var, const size_t& idx) const { for (size_t i = idx; i < view_.AllOps().size(); ++i) { - auto* op = view_.AllOps()[i]; + ir::Node* op = view_.AllOps()[i]; PADDLE_ENFORCE(op->IsOp() && op->Op()); auto* op_desc = op->Op(); op_desc->RenameInput(var, cache_var); @@ -203,14 +264,28 @@ void InplacePass::InplaceModifyVar(const std::string& var, // redirect the input to the latest version of cache_var for (auto* node : op->inputs) { if (node->Name() == var) { - ir::Node* cache_node = var_nodes_[cache_var].back(); + ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); + var_nodes_[cache_var].emplace_back(cache_node); + // swap node to cache_node cache_node->outputs.insert(cache_node->outputs.end(), node->outputs.begin(), node->outputs.end()); + PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp()); + auto* prev_op = node->inputs[0]; + std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, + cache_node); + cache_node->inputs.emplace_back(prev_op); for (auto* next_op : node->outputs) { std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, cache_node); } + + // release unused var in graph. Because python side memory optimize + // may reused the var in same name, so we only clear the var node + // after current inplaced index. + graph->RemoveNode(node); + auto& nodes = var_nodes_.at(var); + nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end()); } } @@ -220,7 +295,6 @@ void InplacePass::InplaceModifyVar(const std::string& var, if (node->Name() == var) { ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); var_nodes_[cache_var].emplace_back(cache_node); - // swap node to cache node cache_node->outputs.insert(cache_node->outputs.end(), node->outputs.begin(), node->outputs.end()); @@ -230,15 +304,14 @@ void InplacePass::InplaceModifyVar(const std::string& var, std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, cache_node); } + + // release unsed var in graph + graph->RemoveNode(node); + auto& nodes = var_nodes_.at(var); + nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end()); } } } - - // release node of unused var in graph - for (auto* node : var_nodes_[var]) { - graph->RemoveNode(node); - } - var_nodes_.at(var).clear(); } void InplacePass::TryInplaceOpInputOutput(ir::Node* op, @@ -260,6 +333,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, auto& all_ops = view_.AllOps(); auto cursor = std::find(all_ops.begin(), all_ops.end(), op); size_t idx = std::distance(all_ops.begin(), cursor); + VLOG(3) << op->Name() << idx; for (auto& pair : in_to_outs) { auto& in_var_name = pair.first; @@ -286,6 +360,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, } VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), out_var_name, in_var_name); + // VLOG(3) << "Out " << OpDebugString(op); InplaceModifyDesc(out_var_name, in_var_name, idx); InplaceModifyVar(out_var_name, in_var_name, idx, graph); } @@ -319,7 +394,16 @@ ir::Node* GraphView::GetNodeByName(const std::string& name, } std::vector GraphView::PendingOpsOnVar(ir::Node* node) { - return node->outputs; + // get the pending ops depends on same var node. + // because node also maybe a inplaced variable, so need to backtrack all the + // previous inplaced vars. + std::vector pending_ops; + ir::Node* p = node; + while (p != nullptr) { + pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end()); + p = GetPrevCascadeInplacedVar(p); + } + return pending_ops; } void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); } @@ -354,14 +438,14 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) { // get the ops with same output name while (out != nullptr) { out_var_set.emplace(out); - out = GetNextInplacedOpOutput(out); + out = GetNextCascadeInplacedVar(out); } // get ops with same input name ir::Node* in = in_var; while (in != nullptr) { in_var_set.emplace(in); - in = GetPrevInplacedOpInput(in); + in = GetPrevCascadeInplacedVar(in); } // find if there is path with control dep var connect the in_var_set and // out_var_set diff --git a/paddle/fluid/framework/details/memory_optimize_pass_test.cc b/paddle/fluid/framework/details/memory_optimize_pass_test.cc index cde78bc3b2..3d3dfa9359 100644 --- a/paddle/fluid/framework/details/memory_optimize_pass_test.cc +++ b/paddle/fluid/framework/details/memory_optimize_pass_test.cc @@ -18,57 +18,13 @@ #include #include "glog/logging.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/details/graph_test_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -namespace paddle { -namespace framework { - -class DummyOp : public OperatorBase { - public: - DummyOp(const std::string& type, const VariableNameMap& inputs, - const VariableNameMap& outputs, const AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const Scope& scope, - const platform::Place& place) const override {} -}; - -class SumOpMaker : public OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "").AsDuplicable(); - AddOutput("Out", ""); - AddComment(""); - } -}; - -class AssignOpMaker : public OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "").AsDuplicable(); - AddOutput("Out", ""); - AddComment(""); - } -}; - -class DummyVarTypeInference : public VarTypeInference { - public: - void operator()(const OpDesc& op_desc, BlockDesc* block) const override { - auto& inputs = op_desc.Input("X"); - auto type = block->Var(inputs.front())->GetType(); - auto out_var_name = op_desc.Output("Out").front(); - block->Var(out_var_name)->SetType(type); - } -}; - -} // namespace framework -} // namespace paddle - REGISTER_OPERATOR(sum, paddle::framework::DummyOp, paddle::framework::SumOpMaker, paddle::framework::DummyVarTypeInference); @@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() { return prog; } -template -inline static std::string DebugString(const Container& c) { - std::stringstream ss; - for (auto& item : c) { - ss << item << " "; - } - return ss.str(); -} - TEST(CFGGraph, IRGraph) { // prepare ir graph auto prog = FillProgramDesc(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.h b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h index b06c87a5c1..69cac8ad95 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h @@ -19,20 +19,12 @@ #include #include #include -#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/graph_print_pass.h" namespace paddle { namespace framework { namespace details { -constexpr char kGraphvizPath[] = "debug_graphviz_path"; - -class SSAGraphPrinter { - public: - virtual ~SSAGraphPrinter() {} - virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0; -}; - class GraphvizSSAGraphPrinter : public SSAGraphPrinter { public: void Print(const ir::Graph& graph, std::ostream& sout) const override; diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 5ef1d2cfa6..5e5e6033d8 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -40,7 +40,7 @@ class TestParallelExecutorBase(unittest.TestCase): seed=None, use_parallel_executor=True, use_reduce=False, - use_ir_memory_optimize=False, + use_ir_memory_optimize=True, enable_inplace=True, fuse_elewise_add_act_ops=False, fuse_relu_depthwise_conv=False, @@ -61,64 +61,66 @@ class TestParallelExecutorBase(unittest.TestCase): main.random_seed = seed loss = method(use_feed=feed_dict is not None) - if optimizer: optimizer().minimize(loss) if memory_opt: fluid.memory_optimize(main) - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(startup) - exec_strategy = fluid.ExecutionStrategy() - exec_strategy.allow_op_delay = allow_op_delay - if use_fast_executor: - exec_strategy.use_experimental_executor = True - build_strategy = fluid.BuildStrategy() - build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ - if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce - build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops - build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv - build_strategy.memory_optimize = use_ir_memory_optimize - build_strategy.enable_inplace = enable_inplace - build_strategy.enable_sequential_execution = enable_sequential_execution - if use_cuda and core.is_compiled_with_cuda(): - build_strategy.remove_unnecessary_lock = True - if use_parallel_executor: - binary = compiler.CompiledProgram(main).with_data_parallel( - loss_name=loss.name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - else: - binary = compiler.CompiledProgram(main) - - if batch_size is not None: - batch_size *= fluid.core.get_cuda_device_count( - ) if use_cuda else int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - begin = time.time() - first_loss, = run_executor( - exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) - - for i in range(iter): - run_executor( - exe=exe, binary=binary, feed=feed_dict, fetch_list=[]) - - last_loss, = run_executor( - exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) - end = time.time() - - if batch_size is not None: - print("%.4f Instance per second" % ( - (batch_size * iter + 2) / (end - begin))) - - avg_last_loss_val = np.array(last_loss).mean() - avg_first_loss_val = np.array(first_loss).mean() - if math.isnan(float(avg_last_loss_val)) or math.isnan( - float(avg_first_loss_val)): - sys.exit("got NaN loss, training failed.") - - print(first_loss, last_loss) - # self.assertGreater(first_loss[0], last_loss[0]) - return first_loss, last_loss + with open("program_model.txt", "w") as f: + f.write(str(main)) + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.allow_op_delay = allow_op_delay + if use_fast_executor: + exec_strategy.use_experimental_executor = True + build_strategy = fluid.BuildStrategy() + build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ + if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce + build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops + build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv + build_strategy.memory_optimize = use_ir_memory_optimize + build_strategy.enable_inplace = enable_inplace + build_strategy.enable_sequential_execution = enable_sequential_execution + build_strategy.debug_graphviz_path = "debug_ir_graph_" + + if use_cuda and core.is_compiled_with_cuda(): + build_strategy.remove_unnecessary_lock = True + if use_parallel_executor: + binary = compiler.CompiledProgram(main).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + else: + binary = compiler.CompiledProgram(main) + + if batch_size is not None: + batch_size *= fluid.core.get_cuda_device_count( + ) if use_cuda else int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + begin = time.time() + first_loss, = run_executor( + exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) + + for i in range(iter): + run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[]) + + last_loss, = run_executor( + exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) + end = time.time() + + if batch_size is not None: + print("%.4f Instance per second" % ( + (batch_size * iter + 2) / (end - begin))) + + avg_last_loss_val = np.array(last_loss).mean() + avg_first_loss_val = np.array(first_loss).mean() + if math.isnan(float(avg_last_loss_val)) or math.isnan( + float(avg_first_loss_val)): + sys.exit("got NaN loss, training failed.") + + print(first_loss, last_loss) + # self.assertGreater(first_loss[0], last_loss[0]) + return first_loss, last_loss diff --git a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py new file mode 100644 index 0000000000..0c9cd99322 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py @@ -0,0 +1,69 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import unittest +import numpy as np +import paddle.fluid as fluid +from parallel_executor_test_base import TestParallelExecutorBase + + +def fc_with_batchnorm(use_feed): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + hidden = img + for _ in range(3): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + hidden = fluid.layers.batch_norm(input=hidden) + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestIrInplace(TestParallelExecutorBase): + @classmethod + def setUpClass(cls): + os.environ['CPU_NUM'] = str(4) + + def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace): + np.random.seed(5) + img = np.random.random(size=[32, 784]).astype(np.float32) + label = np.ones(shape=[32, 1], dtype='int64') + self.check_network_convergence( + fc_with_batchnorm, + feed_dict={"image": img, + "label": label}, + use_cuda=True, + memory_opt=False, # inplace is conflict with memory opt + use_ir_memory_optimize=ir_memory_optimize, + enable_inplace=enable_inplace) + + def test_fc_with_batchnorm(self, delta=1e-3): + loss00 = self._fc_with_batchnorm(False, False) + loss10 = self._fc_with_batchnorm(True, False) + loss01 = self._fc_with_batchnorm(False, True) + loss11 = self._fc_with_batchnorm(True, True) + self.assertAlmostEqual(loss00, loss10, delta=delta) + self.assertAlmostEqual(loss00, loss01, delta=delta) + self.assertAlmostEqual(loss00, loss11, delta=delta) -- GitLab