diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 66f11dedbaccd7febcd75fa7ade9c68b6c42022c..910318a49cea50fadd29b1427a4591abfa5d5a23 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -128,7 +128,7 @@ cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) -cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) @@ -192,6 +192,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry proto_desc) +cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index d5966ad5a97a97ec40c8a01d2d2c8ed5d7f90421..6621a59d37a670f7025507faeab5b9897794a72e 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -50,7 +50,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) -cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass) +cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper) +cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass) +cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) @@ -65,12 +67,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) -set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass) +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass inplace_op_pass) if (WITH_GPU) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) endif() -cc_test(memory_reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types.cc DEPS framework_proto graph) -cc_test(analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types.cc DEPS framework_proto graph graph_helper op_registry pass) +cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph) +cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index ce5731a1f414e8ef6d8af22a3bb17109e82beb87..51ce9732722efa44d2489f5b77694094e58c8775 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" @@ -47,6 +47,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("sequential_execution_pass"); } + // Add op fusion. + if (strategy.fuse_relu_depthwise_conv_) { + AppendPass("fuse_relu_depthwise_conv_pass"); + } + + // NOTE(dzhwinter): A note for automatical inplace. + // 1. modify program desc passes should put + // before inplace pass. + // 2. manually configured inplace should put + // before inplace_pass + + // Add automatically inplace. + if (strategy_.enable_inplace_) { + AppendPass("inplace_pass"); + } + // Add a graph viz pass to record a graph. if (!strategy_.debug_graphviz_path_.empty()) { auto viz_pass = AppendPass("graph_viz_pass"); @@ -55,10 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { viz_pass->Set("graph_viz_path", new std::string(graph_path)); } - // Add op fusion. - if (strategy.fuse_relu_depthwise_conv_) { - AppendPass("fuse_relu_depthwise_conv_pass"); - } if (strategy.fuse_elewise_add_act_ops_) { auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass"); // Add a graph viz pass to record a graph. @@ -88,7 +100,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // A side-effect of that, memory optimize cannot forsee the fetched vars // , so fetchlist should be set persistable before call the Run interface. if (strategy.memory_optimize_) { - auto analysis_var_pass = AppendPass("analysis_var_pass"); + auto memory_optimize_pass = AppendPass("memory_optimize_pass"); } AppendMultiDevPass(strategy); @@ -186,8 +198,10 @@ std::unique_ptr BuildStrategy::Apply( pass->Erase("nccl_ctxs"); pass->SetNotOwned("nccl_ctxs", nctx); #endif - - } else if (pass->Type() == "analysis_var_pass") { + } else if (pass->Type() == "memory_optimize_pass") { + if (graph->Has(kAllOpDescs)) { + graph->Erase(kAllOpDescs); + } const std::vector *all_op_descs = new std::vector(main_program.Block(0).AllOps()); graph->Set>(kAllOpDescs, @@ -214,6 +228,13 @@ std::unique_ptr BuildStrategy::Apply( pass->Set>( kAllOpDescs, new std::vector(main_program.Block(0).AllOps())); + } else if (pass->Type() == "inplace_pass") { + if (graph->Has(kAllOpDescs)) { + graph->Erase(kAllOpDescs); + } + graph->Set>( + kAllOpDescs, + new std::vector(main_program.Block(0).AllOps())); } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") { if (!use_cuda) { LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on " @@ -239,9 +260,10 @@ USE_PASS(allreduce_mode_multi_devices_pass); USE_PASS(dist_multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); -USE_PASS(analysis_var_pass); +USE_PASS(memory_optimize_pass); USE_PASS(sequential_execution_pass); USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); +USE_PASS(inplace_pass); USE_PASS(lock_free_optimize_pass); USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index cd24a3175953bf323748bf0c7e3159761c13f0a9..e3e06a5614ddee0bea342bc3608691b7a32326cc 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -80,6 +80,11 @@ struct BuildStrategy { bool memory_early_delete_{false}; + // TODO(dzhwinter): + // make enable_inplace, memory_optimize_ + // memory_early_delete_ true by default + bool enable_inplace_{false}; + bool enable_sequential_execution_{false}; bool fuse_broadcast_op_{false}; diff --git a/paddle/fluid/framework/details/graph_test_base.h b/paddle/fluid/framework/details/graph_test_base.h new file mode 100644 index 0000000000000000000000000000000000000000..126959bcd80a4677f76b7cff677a82a319f7cfb3 --- /dev/null +++ b/paddle/fluid/framework/details/graph_test_base.h @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { + +class DummyOp : public OperatorBase { + public: + DummyOp(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} +}; + +class SumOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class AssignOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SplitOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", ""); + AddOutput("Out", "").AsDuplicable(); + AddComment(""); + } +}; + +class DummyVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDesc& op_desc, BlockDesc* block) const override { + auto& inputs = op_desc.Input("X"); + auto type = block->Var(inputs.front())->GetType(); + auto out_var_name = op_desc.Output("Out").front(); + block->Var(out_var_name)->SetType(type); + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..64368a5e8737b2484bda9b7dd52451b4d4f760ff --- /dev/null +++ b/paddle/fluid/framework/details/inplace_op_pass.cc @@ -0,0 +1,431 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/inplace_op_pass.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/details/memory_optimize_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_info.h" + +// NOTE(dzhwinter): inplace means one op output variable reuse the input space. +// By our design, one operator only can read its input(const Variable), +// write its output(non-const Variable). If one operator is inplaced, means +// user have chance to write the space before reading happens. +// Especially when some optimize code writing style is applied. +// +// +// /* wrong case in operator */ +// /*In this case, a larger allocation is allocated, input content is lost*/ +// const Tensor* in = ctx.Input("In") +// Tensor* out = ctx.Output("Out"); +// auto* out_ptr = out->mutable_data(ctx.GetPlace()); +// out_ptr[0] = 0; // input contect is overwrited. + +// NOTE(dzhwinter): +// Only for backward compacity and stable. if enable_inplace_whitelist is turn +// on. +// only the ops in whitelist will be use inplace strategy. +// if not, all the op will be inplaced if it registered with InplaceClass +DEFINE_bool( + enable_inplace_whitelist, false, + "If this option turns on, only these op in whitelist can be inplaced." + "If it turns off, all of the running op can be candidate of inplaced op." + "Such as scale, elementwise_add" + "By default, it's turned on"); + +DECLARE_string(memory_optimize_debug); + +// clang-format off +const std::string kInplacedOpWhiteList[] = { // NOLINT + "sigmoid", + "exp", + "relu", + "tanh", + "sqrt", + "ceil", + "floor", + "reciprocal", + "relu6", + "soft_relu", + "hard_sigmoid", + "batch_norm", + "batch_norm_grad", + "sum", + "sum_grad", + "scale", + "reshape", + "elementwise_add", + "elementwise_add_grad", +}; +// clang-format on + +namespace paddle { +namespace framework { +namespace details { + +static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { + // if next op is inplaced, then return the output var + // otherwise return nullptr + PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); + ir::Node* inplaced_var = nullptr; + for (auto* next_op : var->outputs) { + for (auto* output : next_op->outputs) { + if (output->IsVar() && !output->IsCtrlVar() && + output->Name() == var->Name()) { + inplaced_var = output; + } + } + } + return inplaced_var; +} + +static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) { + PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); + if (var->inputs.empty()) return nullptr; + auto* prev_op = var->inputs.at(0); + auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(), + [&](ir::Node* node) { + if (node->IsVar() && !node->IsCtrlVar() && + node->Name() == var->Name()) { + return true; + } else { + return false; + } + }); + return input_it == prev_op->inputs.end() ? nullptr : *input_it; +} + +InplacePass::InplacePass() : Pass() { + if (FLAGS_enable_inplace_whitelist) { + for (auto& s : kInplacedOpWhiteList) { + whitelist_.emplace(s); + } + } +} + +void InplacePass::InitSSAGraphNodes() const { + std::unordered_map> all_vars; + for (auto* op : view_.AllOps()) { + for (auto* node : op->inputs) { + if (!node->IsVar() || node->IsCtrlVar()) continue; + if (all_vars[node->Name()].count(node) == 0) { + all_vars[node->Name()].emplace(node); + var_nodes_[node->Name()].emplace_back(node); + } + } + for (auto* node : op->outputs) { + if (!node->IsVar() || node->IsCtrlVar()) continue; + if (all_vars[node->Name()].count(node) == 0) { + all_vars[node->Name()].emplace(node); + var_nodes_[node->Name()].emplace_back(node); + } + } + } +} + +std::unique_ptr InplacePass::ApplyImpl( + std::unique_ptr graph) const { + var_nodes_.clear(); + view_.Build(graph.get()); + InitSSAGraphNodes(); + + for (auto* op : view_.AllOps()) { + if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) + continue; + TryInplaceOpInputOutput(op, graph.get()); + } + graph->ResolveHazard(var_nodes_); + + return graph; +} + +void InplacePass::InplaceModifyDesc(const std::string& var, + const std::string& cache_var, + const size_t& idx) const { + for (size_t i = idx; i < view_.AllOps().size(); ++i) { + ir::Node* op = view_.AllOps()[i]; + PADDLE_ENFORCE(op->IsOp() && op->Op()); + auto* op_desc = op->Op(); + op_desc->RenameInput(var, cache_var); + op_desc->RenameOutput(var, cache_var); + if (op_desc->Block()->HasVar(var)) op_desc->Block()->RemoveVar(var); + op_desc->Flush(); + } +} + +const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var, + const std::string& cache_var, + const size_t& idx, + ir::Graph* graph) const { + PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && + var_nodes_[var].at(0)->Var() != nullptr); + std::unique_ptr var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); + var_desc->SetName(cache_var); + + SSANodePair swap_nodes; + + for (size_t i = idx; i < view_.AllOps().size(); ++i) { + auto* op = view_.AllOps()[i]; + + // redirect the input to the latest version of cache_var + for (auto* node : op->inputs) { + if (node->Name() == var) { + ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); + + // swap node to cache_node + cache_node->outputs.insert(cache_node->outputs.end(), + node->outputs.begin(), node->outputs.end()); + PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp()); + auto* prev_op = node->inputs[0]; + std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, + cache_node); + cache_node->inputs.emplace_back(prev_op); + for (auto* next_op : node->outputs) { + std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, + cache_node); + } + + swap_nodes.emplace_back(std::make_pair(node, cache_node)); + } + } + + // if we need to rename the output, + // always create a newer version of cache_var + for (auto* node : op->outputs) { + if (node->Name() == var) { + ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); + // swap node to cache node + cache_node->outputs.insert(cache_node->outputs.end(), + node->outputs.begin(), node->outputs.end()); + cache_node->inputs.emplace_back(op); + std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node); + for (auto* next_op : node->outputs) { + std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, + cache_node); + } + + swap_nodes.emplace_back(std::make_pair(node, cache_node)); + } + } + } + + return swap_nodes; +} + +void InplacePass::CommitModify(const SSANodePair& swap_nodes, + ir::Graph* graph) const { + for (auto& pair : swap_nodes) { + auto *node = pair.first, *cache_node = pair.second; + const std::string var = node->Name(), cache_var = cache_node->Name(); + var_nodes_[cache_var].emplace_back(cache_node); + graph->RemoveNode(node); + auto& nodes = var_nodes_.at(var); + // release unused var in graph. Because python side memory optimize + // may reused the var in same name, so we only clear the var node + // after current inplaced index. + nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end()); + } +} + +void InplacePass::WithdrawModify(const SSANodePair& nodes, + ir::Graph* graph) const { + for (auto& pair : nodes) { + auto *node = pair.first, *cache_node = pair.second; + const std::string var = node->Name(), cache_var = cache_node->Name(); + auto* prev_op = node->inputs[0]; + std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node, + node); + for (auto* next_op : node->outputs) { + std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node, + node); + } + graph->RemoveNode(cache_node); + } +} + +void InplacePass::TryInplaceOpInputOutput(ir::Node* op, + ir::Graph* graph) const { + VLOG(4) << "Try to inplace op " << op->Name(); + PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr, + "op_desc is nullptr"); + // some pre-requirments need to meet if the op want to inplaced. + + auto* op_desc = op->Op(); + auto& infer_inplace = + OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_; + + // 1. infer_inplace_ is registered. + if (!static_cast(infer_inplace)) return; + PADDLE_ENFORCE(static_cast(infer_inplace), + "%s's infer_inplace has not been registered", op_desc->Type()); + + auto* block = op_desc->Block(); + auto in_to_outs = infer_inplace(*op_desc, block); + + auto& all_ops = view_.AllOps(); + auto cursor = std::find(all_ops.begin(), all_ops.end(), op); + size_t idx = std::distance(all_ops.begin(), cursor); + + for (auto& pair : in_to_outs) { + auto& in_var_name = pair.first; + auto& out_var_name = pair.second; + auto* in_node = view_.GetNodeByName(in_var_name, op->inputs); + auto* out_node = view_.GetNodeByName(out_var_name, op->outputs); + + // 2. there is no external pending op on the input node + if (view_.PendingOpsOnVar(in_node).size() > 1) { + VLOG(4) << string::Sprintf( + "Skiped pair %s => %s. %s input has external dependency." + "inplace such pair will overwrite the memory.", + out_var_name, in_var_name, op->Name()); + continue; + } + + // 3. if output has been memory optimize by python(fluid.memory_optmize()). + // this candidate can not be inplaced. Will be deprecated in the future. + if (view_.InSkipSet(out_node->Name())) { + VLOG(4) << string::Sprintf( + "Skiped %s => %s reused previous memory block in python memory " + "optmize," + "it inplace may generate a circle", + out_var_name, in_var_name, op->Name()); + continue; + } + + // Debug Interface. Which would be skipped by the pass. + if (out_node->Name() == FLAGS_memory_optimize_debug) { + VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" + << out_node->Name(); + continue; + } + + // NOTE(dzhwinter): + // two stage commit of inplaced process. if after inplace happens generate a + // circle, + // then withdraw the changes. Otherwise, safely add the node. + auto swap_nodes = + TryInplaceModifyVar(out_var_name, in_var_name, idx, graph); + + if (!ir::HasCircle(*graph)) { + VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), + out_var_name, in_var_name); + InplaceModifyDesc(out_var_name, in_var_name, idx); + CommitModify(swap_nodes, graph); + } else { + VLOG(3) << string::Sprintf( + "Skiped pair %s => %s, inplace will generate a circle. withdraw %s", + out_var_name, in_var_name, op->Name()); + WithdrawModify(swap_nodes, graph); + } + } +} + +ir::Node* GraphView::GetNodeByName(const std::string& name, + const std::vector& nodes) const { + // nodes should be op->inputs/outputs + // node in same node do have different name. + std::unordered_set nodes_in_op; + bool has_dup_node = + std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) { + if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) { + if (nodes_in_op.count(node->Name())) return true; + nodes_in_op.emplace(node->Name()); + } + return false; + }); + PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!"); + ir::Node* node = nullptr; + for (auto* it : nodes) { + if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue; + if (it->Name() == name) { + node = it; + break; + } + } + PADDLE_ENFORCE(node != nullptr, + string::Sprintf("Not found var %s in nodes!", name)); + return node; +} + +std::vector GraphView::PendingOpsOnVar(ir::Node* node) { + // get the pending ops depends on same var node. + // because node also maybe a inplaced variable, so need to backtrack all the + // previous inplaced vars. + std::vector pending_ops; + ir::Node* p = node; + while (p != nullptr) { + pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end()); + p = GetPrevCascadeInplacedVar(p); + } + return pending_ops; +} + +void GraphView::Build(ir::Graph* g) { + // track the var nodes in correct order. + // Because we insert some new created node. Which may have data race between + // nodes. + // resolve data harzards depends on the var nodes in right order. + ops_ = SortOpLikeDescOrder(*g); + + // 1. track the nodes which reused previous node in Python memory optimize. + // these node can not be inplaced, otherwise may generate a circle in graph. + std::unordered_set all_vars; + for (auto& node : g->Nodes()) { + if (node->IsVar()) continue; + for (auto& out : node->outputs) { + if (out->IsCtrlVar() || out->Var() == nullptr) continue; + if (all_vars.count(out->Name())) { + dup_nodes_.emplace(out->Name()); + } else { + all_vars.emplace(out->Name()); + } + } + } + + // 2. track the nodes which used by parameter server. + // these node can not be inplaced, otherwise trainer + // pserver can not find each other name. + for (auto& node : g->Nodes()) { + if (!node->IsOp()) continue; + if (node->Name() == "send") { + for (auto& in : node->inputs) { + dup_nodes_.emplace(in->Name()); + } + } + if (node->Name() == "recv") { + for (auto& out : node->outputs) { + dup_nodes_.emplace(out->Name()); + } + } + } +} + +const std::vector& GraphView::AllOps() { return ops_; } + +bool GraphView::InSkipSet(const std::string& var) const { + return dup_nodes_.count(var); +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass); diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..1abcf1f279e225839d440ff9c6840ce9b8a6547f --- /dev/null +++ b/paddle/fluid/framework/details/inplace_op_pass.h @@ -0,0 +1,93 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may abtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/details/memory_optimize_helper.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class GraphView { + public: + GraphView() = default; + + void Build(ir::Graph* g); + + const std::vector& AllOps(); + + ir::Node* GetNodeByName(const std::string& name, + const std::vector& nodes) const; + + std::vector PendingOpsOnVar(ir::Node* var); + + // Will Deperated in the future. + // NOTE(dzhwinter) : + // 1. Python memory optimize will reuse + // memory based var name, so different op output may + // have the same variable name. enable inplace on such node + // will generate a circle in ssa graph. + // 2. DistributeTranspiler will use unique name to + // map the parameter and gradient, must be skipped. + bool InSkipSet(const std::string& var) const; + + private: + std::vector ops_; + std::unordered_set dup_nodes_; // mem opt affect nodes + std::map> adj_list_; +}; + +typedef std::vector> SSANodePair; +class InplacePass : public ir::Pass { + public: + InplacePass(); + + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; + + void InitSSAGraphNodes() const; + + private: + const SSANodePair TryInplaceModifyVar(const std::string& var, + const std::string& cache_var, + const size_t& idx, + ir::Graph* graph) const; + + void CommitModify(const SSANodePair&, ir::Graph* graph) const; + + void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const; + + void InplaceModifyDesc(const std::string& in_var, const std::string& out_var, + const size_t& idx) const; + + void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const; + + mutable std::map> var_nodes_; + + mutable std::unordered_set whitelist_; + mutable GraphView view_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/memory_early_delete_pass.cc b/paddle/fluid/framework/details/memory_early_delete_pass.cc index 5906b7d57ce122520a4594f1528e00982eaa1a7f..69f8f705484450b0544291b19027eb174d7eeb8f 100644 --- a/paddle/fluid/framework/details/memory_early_delete_pass.cc +++ b/paddle/fluid/framework/details/memory_early_delete_pass.cc @@ -16,7 +16,7 @@ #include #include #include -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h" diff --git a/paddle/fluid/framework/details/memory_reuse_types.cc b/paddle/fluid/framework/details/memory_optimize_helper.cc similarity index 66% rename from paddle/fluid/framework/details/memory_reuse_types.cc rename to paddle/fluid/framework/details/memory_optimize_helper.cc index 2b9ff518b9adcd366cc877998400a8bdc05fa033..b56ef021ef508a43aac082acbcfa6f543635203e 100644 --- a/paddle/fluid/framework/details/memory_reuse_types.cc +++ b/paddle/fluid/framework/details/memory_optimize_helper.cc @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" +#include #include +#include #include #include @@ -21,15 +23,17 @@ namespace paddle { namespace framework { namespace details { +size_t NodeSizeInBytes(const VarDesc& node) { + auto shape = node.GetShape(); + int size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + size_t type_size = SizeOfType(node.GetDataType()); + return type_size * std::abs(size); +} + size_t NodeSizeInBytes(ir::Node* n) { auto* desc = FindVarDescInBlock(n); - auto shape = desc->GetShape(); - size_t type_size = SizeOfType(desc->GetDataType()); - int size = 1; - for (auto& s : shape) { - size *= s; - } - return type_size * std::abs(size); + return NodeSizeInBytes(*desc); } std::string DebugStringImpl(VarDesc* var) { @@ -83,7 +87,7 @@ struct NodeComparator { } }; -void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) { +void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) { PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar()); PADDLE_ENFORCE(op->IsOp()); if (mark_table_.count(var->Name()) != 0) { @@ -119,11 +123,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) { mark_table_[var->Name()] = it; } -int OrderedNodePairPool::GetIndex(ir::Node* var) { +int OrderedNodeList::GetIndex(ir::Node* var) { return std::distance(nodes_.begin(), mark_table_[var->Name()]); } -ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const { +ir::Node* OrderedNodeList::NodeMatch(ir::Node* var) const { ir::Node* found_node = nullptr; NodeComparator compare_node; @@ -136,13 +140,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const { return found_node; } -void OrderedNodePairPool::Erase(ir::Node* var) { - PADDLE_ENFORCE(mark_table_.count(var->Name())); - nodes_.erase(mark_table_[var->Name()]); - mark_table_.erase(var->Name()); +void OrderedNodeList::Erase(ir::Node* var) { Erase(var->Name()); } + +void OrderedNodeList::Erase(const std::string& var) { + PADDLE_ENFORCE(mark_table_.count(var)); + nodes_.erase(mark_table_[var]); + mark_table_.erase(var); } -std::string OrderedNodePairPool::ToString() const { +std::string OrderedNodeList::ToString() const { std::stringstream ss; for (auto it = nodes_.begin(); it != nodes_.end(); ++it) { ss << DebugString(it->first) << " "; @@ -150,6 +156,43 @@ std::string OrderedNodePairPool::ToString() const { return ss.str(); } +bool NodeCanReused(ir::Node* node) { + if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false; + // auto* desc = node->Var(); + bool flag = NodeCanReused(*node->Var()); + for (auto* op : node->inputs) { + if (op->Op()->HasAttr("force_cpu")) { + // op output force generated in cpu, can not be reused. + flag &= framework::AttrReader(op->Op()->GetAttrMap()) + .Get("force_cpu") == 0; + } + } + return flag; +} + +bool NodeCanReused(const VarDesc& node) { + auto type = node.GetType(); + if (node.Persistable() || type != proto::VarType::LOD_TENSOR || + node.GetShape().empty()) { + return false; + } + // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad + std::string name = node.Name(); + if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@') + return false; + return true; +} + +bool OpHasSubBlock(OpDesc* desc) { + const AttributeMap& attrs = desc->GetAttrMap(); + for (auto& attr : attrs) { + if (attr.second.type() == typeid(BlockDesc*) || // NOLINT + attr.second.type() == typeid(std::vector)) // NOLINT + return true; + } + return false; +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/memory_reuse_types.h b/paddle/fluid/framework/details/memory_optimize_helper.h similarity index 69% rename from paddle/fluid/framework/details/memory_reuse_types.h rename to paddle/fluid/framework/details/memory_optimize_helper.h index 9a9c1d948e869016717fea9ff6b8236adfc29845..064183d61ea7386b6b45034c90fd7569a8647f60 100644 --- a/paddle/fluid/framework/details/memory_reuse_types.h +++ b/paddle/fluid/framework/details/memory_optimize_helper.h @@ -43,7 +43,7 @@ using GraphNodePool = std::vector< // For example, // node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], .. // O(1) insert, delete -class OrderedNodePairPool { +class OrderedNodeList { public: using NodePair = std::pair>; using Iter = typename std::list::iterator; @@ -53,8 +53,12 @@ class OrderedNodePairPool { void Erase(ir::Node* var); + void Erase(const std::string& var); + bool Has(ir::Node* var) { return mark_table_.count(var->Name()); } + bool Has(const std::string& var) { return mark_table_.count(var); } + ir::Node* NodeMatch(ir::Node* var) const; // map store non-const iterator, can not promise const int GetIndex(ir::Node* var); @@ -67,6 +71,11 @@ class OrderedNodePairPool { ConstIter end() const { return nodes_.end(); } size_t size() const { return nodes_.size(); } + void Clear() { + mark_table_.clear(); + nodes_.clear(); + } + private: // for searching. std::unordered_map mark_table_; @@ -74,14 +83,53 @@ class OrderedNodePairPool { std::list nodes_; }; +// valid a tensor can be reuse or not +bool NodeCanReused(ir::Node* node); + +// valid a tensor can be reuse or not. +bool NodeCanReused(const VarDesc& node); + +// check op has subblock or not +bool OpHasSubBlock(OpDesc* desc); + // node memory size in bytes size_t NodeSizeInBytes(ir::Node* n); +// node memory size in bytes +size_t NodeSizeInBytes(const VarDesc&); + std::string DebugString(ir::Node* var); -// std::string DebugString(VarDesc* var); VarDesc* FindVarDescInBlock(ir::Node* n); +template +class FilterVariableImpl { + public: + void operator()(const Container& nodes, Callback callback) { + for (auto* node : nodes) { + callback(node); + } + } +}; + +// filter var node for op->inputs/outputs +template +class FilterVariableImpl, Callback> { + public: + void operator()(const std::vector& nodes, Callback callback) { + for (auto* var : nodes) { + if (var->IsVar() && !var->IsCtrlVar()) { + callback(var); + } + } + } +}; + +template +void FilterVariables(const Container& nodes, Callback callback) { + FilterVariableImpl()(nodes, callback); +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/memory_reuse_types_test.cc b/paddle/fluid/framework/details/memory_optimize_helper_test.cc similarity index 96% rename from paddle/fluid/framework/details/memory_reuse_types_test.cc rename to paddle/fluid/framework/details/memory_optimize_helper_test.cc index d2fabf5ce068e0f752b86c0d02b971f18fc65f01..f2b9baf14a34ace9cc860797280dbd519dfa4f2a 100644 --- a/paddle/fluid/framework/details/memory_reuse_types_test.cc +++ b/paddle/fluid/framework/details/memory_optimize_helper_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include #include #include @@ -27,8 +27,8 @@ namespace paddle { namespace framework { namespace details { -TEST(OrderedNodePairPool, Normal) { - OrderedNodePairPool pool; +TEST(OrderedNodeList, Normal) { + OrderedNodeList pool; std::vector> nodes; // clang-format off diff --git a/paddle/fluid/framework/details/analysis_var_pass.cc b/paddle/fluid/framework/details/memory_optimize_pass.cc similarity index 80% rename from paddle/fluid/framework/details/analysis_var_pass.cc rename to paddle/fluid/framework/details/memory_optimize_pass.cc index 223b9da3cfba33fc32d1334cddccb9f503bd0bef..33ca45668e86bdbe615b91366db7e286258dd7d6 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.cc +++ b/paddle/fluid/framework/details/memory_optimize_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/analysis_var_pass.h" +#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include #include #include @@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) { op1->Outputs() == op2->Outputs(); } -template -class FilterVariableImpl { - public: - void operator()(const Container& nodes, Callback callback) { - for (auto* node : nodes) { - callback(node); - } - } -}; - -// filter var node for op->inputs/outputs -template -class FilterVariableImpl, Callback> { - public: - void operator()(const std::vector& nodes, Callback callback) { - for (auto* var : nodes) { - if (var->IsVar() && !var->IsCtrlVar()) { - callback(var); - } - } - } -}; - -template -void FilterVariables(const Container& nodes, Callback callback) { - FilterVariableImpl()(nodes, callback); -} - -std::unique_ptr AnalysisVarPass::ApplyImpl( +std::unique_ptr MemoryOptimizePass::ApplyImpl( std::unique_ptr graph) const { auto nodes = graph->Nodes(); auto subblock_vars = GetSubBlockVars(nodes); @@ -103,48 +75,53 @@ std::unique_ptr AnalysisVarPass::ApplyImpl( } for (auto& var : op->outputs) { - if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) { - ir::Node* cache = pool_.NodeMatch(var); - if (var->Name() == FLAGS_memory_optimize_debug) { - VLOG(3) << "start match var " << DebugString(var) << " of op " - << op->Name(); - VLOG(3) << pool_.ToString(); - VLOG(3) << "matched in pool : " - << ((cache == nullptr) ? "False" : "True"); - } - if (cache != nullptr) { - if (var->Name() == cache->Name()) { - VLOG(3) << "The same cache variable is cascade reused." - << var->Name() << " is re-filled to the pool after" - << "the reused op is finished. Current op can not " - << "replace it again. Skip this candidate."; - continue; - } + if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 || + skip_set_.count(var->Name())) + continue; + ir::Node* cache = pool_.NodeMatch(var); + + if (var->Name() == FLAGS_memory_optimize_debug) { + VLOG(3) << "start match var " << DebugString(var) << " of op " + << op->Name(); + VLOG(3) << pool_.ToString(); + VLOG(3) << "matched in pool : " + << ((cache == nullptr) ? "False" : "True"); + } - int node_idx_in_pool = pool_.GetIndex(cache); - VLOG(3) << string::Sprintf( - "!!! %s, %s => %s, cache idx %d, pool size %d", - std::to_string(reuse_id++), DebugString(var), DebugString(cache), - node_idx_in_pool, static_cast(pool_.size())); - // update CFG Graph on the fly. - // reused var maybe re-fill into the pool - cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx); - // NOTE(dzhwinter): we need to both update the ProgramDesc - // and IR Graph. because op_desc/var_desc is used in CreateOp, - // CreateVar when running happens. But IR Graph - // define the dependence relationship between nodes. - RenameVarInGraphDesc(var->Name(), cache->Name(), idx); - RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get()); - - pool_.Erase(cache); + if (cache == nullptr) continue; + if (var->Name() == cache->Name()) { + VLOG(3) << "The same cache variable is cascade reused." << var->Name() + << " is re-filled to the pool after" + << "the reused op is finished. Current op can not " + << "replace it again. Skip this candidate."; + continue; + + int node_idx_in_pool = pool_.GetIndex(cache); + VLOG(3) << string::Sprintf( + "!!! %s, %s => %s, cache idx %d, pool size %d", + std::to_string(reuse_id++), DebugString(var), DebugString(cache), + node_idx_in_pool, static_cast(pool_.size())); + // update CFG Graph on the fly. + // reused var maybe re-fill into the pool + cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx); + // NOTE(dzhwinter): we need to both update the ProgramDesc + // and IR Graph. because op_desc/var_desc is used in CreateOp, + // CreateVar when running happens. But IR Graph + // define the dependence relationship between nodes. + RenameVarInGraphDesc(var->Name(), cache->Name(), idx); + RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get()); + + pool_.Erase(cache); + } + // fill the pool + std::unordered_set unlived_vars; + for (auto var : cfg_->LiveIn(op)) { + if (cfg_->LiveOut(op).count(var) == 0) { + unlived_vars.emplace(var); } } - } - // fill the pool - for (auto var : cfg_->LiveIn(op)) { - if (cfg_->LiveOut(op).count(var) == 0) { + for (auto var : unlived_vars) { ir::Node* var_node = cfg_->GetNodeFromVarName(var, op); - if (var_node == nullptr) continue; if (NodeCanReused(var_node) && !pool_.Has(var_node)) { pool_.Insert(var_node, op); } @@ -177,7 +154,7 @@ std::unique_ptr AnalysisVarPass::ApplyImpl( return graph; } -void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { +void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const { // conditional block, while op and their grad op auto* sub_block_desc = AttrReader(op_desc->GetAttrMap()).Get("sub_block"); @@ -247,7 +224,7 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { } } -std::unordered_set AnalysisVarPass::GetSubBlockVars( +std::unordered_set MemoryOptimizePass::GetSubBlockVars( const std::unordered_set& nodes) const { std::unordered_set vars; for (auto& op : nodes) { @@ -263,9 +240,9 @@ std::unordered_set AnalysisVarPass::GetSubBlockVars( return vars; } -void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, - const std::string& cache_var, - size_t idx) const { +void MemoryOptimizePass::RenameVarInGraphDesc(const std::string& var, + const std::string& cache_var, + size_t idx) const { for (size_t i = idx; i < cfg_->Ops().size(); ++i) { auto* op = cfg_->Ops()[i]; PADDLE_ENFORCE(op->IsOp() && op->Op()); @@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, } } -void AnalysisVarPass::InitSSAGraphNodes() const { +void MemoryOptimizePass::InitSSAGraphNodes() const { std::unordered_map> all_vars; if (var_nodes_.empty()) { for (auto* op : cfg_->Ops()) { @@ -297,9 +274,10 @@ void AnalysisVarPass::InitSSAGraphNodes() const { } } -void AnalysisVarPass::RenameVarInGraphNode(const std::string& var, - const std::string& cache_var, - size_t idx, ir::Graph* graph) const { +void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, + const std::string& cache_var, + size_t idx, + ir::Graph* graph) const { // if replace happens, we need to create a newer version cache_var // but use the same dims/data_type with var. PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && @@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var, var_nodes_.at(var).clear(); } -bool AnalysisVarPass::NodeCanReused(ir::Node* node) const { - if (!node->IsVar() || node->IsCtrlVar()) return false; - auto* desc = node->Var(); - auto type = desc->GetType(); - if (desc->Persistable() || type != proto::VarType::LOD_TENSOR || - desc->GetShape().empty()) { - return false; - } - // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad - std::string name = node->Name(); - if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@') - return false; - if (skip_set_.count(name)) return false; - for (auto* op : node->inputs) { - if (op->Op()->HasAttr("force_cpu")) { - // op output force generated in cpu, can not be reused. - return framework::AttrReader(op->Op()->GetAttrMap()) - .Get("force_cpu") == 0; - } - } - return true; -} - -bool AnalysisVarPass::OpHasSubBlock(OpDesc* desc) const { - const AttributeMap& attrs = desc->GetAttrMap(); - for (auto& attr : attrs) { - if (attr.second.type() == typeid(BlockDesc*) || // NOLINT - attr.second.type() == typeid(std::vector)) // NOLINT - return true; - } - return false; -} - std::vector SortOpLikeDescOrder(const ir::Graph& graph) { PADDLE_ENFORCE(graph.Has(kAllOpDescs), "Graph has no attribute of kAllOpDescs."); @@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name, } // namespace framework } // namespace paddle -REGISTER_PASS(analysis_var_pass, paddle::framework::details::AnalysisVarPass) +REGISTER_PASS(memory_optimize_pass, + paddle::framework::details::MemoryOptimizePass) .RequireGraphAttr(paddle::framework::details::kGraphNodePool) .RequireGraphAttr(paddle::framework::details::kAllOpDescs); diff --git a/paddle/fluid/framework/details/analysis_var_pass.h b/paddle/fluid/framework/details/memory_optimize_pass.h similarity index 90% rename from paddle/fluid/framework/details/analysis_var_pass.h rename to paddle/fluid/framework/details/memory_optimize_pass.h index 144204beafb341351172c29e3b4cd41db49be6f9..b3e026e0bc1e222e82a22b343c86ddc87a967e8f 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.h +++ b/paddle/fluid/framework/details/memory_optimize_pass.h @@ -25,7 +25,7 @@ #include #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" @@ -35,12 +35,10 @@ namespace details { constexpr char kAllOpDescs[] = "all_op_descs"; std::vector SortOpLikeDescOrder(const ir::Graph& graph); -// sort op in bfs order -std::vector BFSSortGraphOps(const ir::Graph& graph); class ControlFlowGraph; -class AnalysisVarPass : public ir::Pass { +class MemoryOptimizePass : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; @@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass { ir::Graph* graph) const; void SubGraphOptimize(OpDesc* op_desc) const; - // valid a tensor can be reuse or not - bool NodeCanReused(ir::Node* node) const; // scan subblock and collect the output/input variables. std::unordered_set GetSubBlockVars( const std::unordered_set&) const; - // check op has subblock or not - bool OpHasSubBlock(OpDesc* desc) const; private: // Reuse Node Pool, Owned. - mutable OrderedNodePairPool pool_; + mutable OrderedNodeList pool_; // controlflow Graph mutable std::unique_ptr cfg_; // skip set diff --git a/paddle/fluid/framework/details/analysis_var_pass_test.cc b/paddle/fluid/framework/details/memory_optimize_pass_test.cc similarity index 90% rename from paddle/fluid/framework/details/analysis_var_pass_test.cc rename to paddle/fluid/framework/details/memory_optimize_pass_test.cc index 9bc4fd33f7058949ca60983ea666a21cb4877b3e..3d3dfa93594d496431f7cb60dceb26f20250fc16 100644 --- a/paddle/fluid/framework/details/analysis_var_pass_test.cc +++ b/paddle/fluid/framework/details/memory_optimize_pass_test.cc @@ -12,63 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/analysis_var_pass.h" +#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include #include #include #include "glog/logging.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/details/graph_test_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -namespace paddle { -namespace framework { - -class DummyOp : public OperatorBase { - public: - DummyOp(const std::string& type, const VariableNameMap& inputs, - const VariableNameMap& outputs, const AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const Scope& scope, - const platform::Place& place) const override {} -}; - -class SumOpMaker : public OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "").AsDuplicable(); - AddOutput("Out", ""); - AddComment(""); - } -}; - -class AssignOpMaker : public OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "").AsDuplicable(); - AddOutput("Out", ""); - AddComment(""); - } -}; - -class DummyVarTypeInference : public VarTypeInference { - public: - void operator()(const OpDesc& op_desc, BlockDesc* block) const override { - auto& inputs = op_desc.Input("X"); - auto type = block->Var(inputs.front())->GetType(); - auto out_var_name = op_desc.Output("Out").front(); - block->Var(out_var_name)->SetType(type); - } -}; - -} // namespace framework -} // namespace paddle - REGISTER_OPERATOR(sum, paddle::framework::DummyOp, paddle::framework::SumOpMaker, paddle::framework::DummyVarTypeInference); @@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() { return prog; } -template -inline static std::string DebugString(const Container& c) { - std::stringstream ss; - for (auto& item : c) { - ss << item << " "; - } - return ss.str(); -} - TEST(CFGGraph, IRGraph) { // prepare ir graph auto prog = FillProgramDesc(); diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index eea7e712f8f6e187cdceedce77cc76d1d4ca2101..0901e59f9786b43361e7a570f8c2a07be54c1ac2 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/grad_op_desc_maker.h" +#include "paddle/fluid/framework/inplace_op_inference.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" @@ -32,7 +33,8 @@ enum OpInfoFillType { kOpProtoAndCheckerMaker = 1, kGradOpDescMaker = 2, kVarTypeInference = 3, - kShapeInference = 4 + kShapeInference = 4, + kInplaceOpInference = 5 }; template @@ -48,8 +50,11 @@ struct OpInfoFillTypeID { ? kVarTypeInference : (std::is_base_of::value ? kShapeInference - : static_cast( - -1))))); + : (std::is_base_of< + InplaceOpInference, T>::value + ? kInplaceOpInference + : static_cast( + -1)))))); } }; @@ -139,6 +144,16 @@ struct OpInfoFiller { } }; +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->infer_inplace_ = [](const OpDesc& op_desc, BlockDesc* block) { + T infer; + return infer(op_desc, block); + }; + } +}; + } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/inplace_op_inference.h b/paddle/fluid/framework/inplace_op_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..03ab2a2b6c5dc07805fddddc3ac53f61e7b6a697 --- /dev/null +++ b/paddle/fluid/framework/inplace_op_inference.h @@ -0,0 +1,115 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/type_defs.h" + +namespace paddle { +namespace framework { + +/* + Inplace Inference for create In->Out pairs for inplaced operator. + If we specify a pair of corresponding names. For example, X->Out. + then Out will inplaced use X's memory. The base class will do + legality validation for both variables. +*/ +class InplaceOpInference { + public: + virtual ~InplaceOpInference() {} + virtual std::unordered_map operator()( + const OpDesc& op_desc, BlockDesc* block) const = 0; +}; + +class InplaceInToOut : public InplaceOpInference { + public: + std::unordered_map operator()( + const OpDesc& op_desc, BlockDesc* block) const { + std::unordered_map ret; + auto in_out_var_names_pair = this->Apply(op_desc, block); + for (auto& pair : in_out_var_names_pair) { + PADDLE_ENFORCE(!op_desc.Input(pair.first).empty(), + string::Sprintf("op %s do not have input of %s!", + op_desc.Type(), pair.first)); + PADDLE_ENFORCE(!op_desc.Output(pair.second).empty(), + string::Sprintf("op %s do not have output of %s!", + op_desc.Type(), pair.second)); + auto& in_name = op_desc.Input(pair.first).at(0); + auto& out_name = op_desc.Output(pair.second).at(0); + + auto in = block->FindRecursiveOrCreateVar(in_name); + auto out = block->FindRecursiveOrCreateVar(out_name); + if (TryInplaceInputOutput(in, out)) ret.insert({in_name, out_name}); + } + return ret; + } + + protected: + virtual std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const = 0; + + bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const { + return in.Name() != out.Name() && details::NodeCanReused(in) && + details::NodeCanReused(out) && + details::NodeSizeInBytes(out) <= details::NodeSizeInBytes(in); + } +}; + +/* + Inplace In and Out for operator only have an Input and an Output. + For example, activation op. + */ +class SingleOpInplaceInToOut : public InplaceInToOut { + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + PADDLE_ENFORCE(!op_desc.InputNames().empty(), + "Op inputs must not be empty"); + PADDLE_ENFORCE(!op_desc.OutputNames().empty(), + "Op outputs must not be empty"); + auto x_name = op_desc.InputNames().at(0); + auto out_name = op_desc.OutputNames().at(0); + return std::unordered_map{{x_name, out_name}}; + } +}; + +/* + Gradient op. Inplace output use it's Input. + For example, Input@Grad->Input reuse strategy. + */ +class GradOpInplaceInToOut : public InplaceInToOut { + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + std::unordered_map ret; + std::unordered_set output_names(op_desc.OutputNames().begin(), + op_desc.OutputNames().end()); + for (auto& input_name : op_desc.InputNames()) { + if (output_names.count(GradVarName(input_name))) { + ret.insert({input_name, GradVarName(input_name)}); + } + } + return ret; + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/inplace_op_inference_test.cc b/paddle/fluid/framework/inplace_op_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..121f648a5f04ae65560ae8d04042e40df61aad50 --- /dev/null +++ b/paddle/fluid/framework/inplace_op_inference_test.cc @@ -0,0 +1,287 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_type_inference.h" + +namespace paddle { +namespace framework { + +class NOP : public OperatorBase { + public: + NOP(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} +}; + +class SingleOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SingleGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("single_op_grad"); + op->SetInput("Out", OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(op); + } +}; + +class SingleOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->HasInput("X"); + ctx->HasOutput("Out"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } +}; + +class SingleGradOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->HasInput(framework::GradVarName("Out")); + ctx->HasOutput(framework::GradVarName("X")); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); + } +}; + +class MultiOutOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddInput("Y", "").AsDuplicable(); + AddInput("Z", "").AsDuplicable(); + AddOutput("Out", ""); + AddOutput("YOut", ""); + AddOutput("ZOut", ""); + AddOutput("NotReuseOut", ""); + AddComment(""); + } +}; + +class MultiOutShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->ShareDim("X", "Out"); + ctx->ShareDim("Y", "YOut"); + ctx->ShareDim("Z", "ZOut"); + } +}; + +class MultiGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("multi_out_grad"); + op->SetInput("X", Input("X")); + op->SetOutput(framework::GradVarName("Y"), OutputGrad("YOut")); + op->SetOutput(framework::GradVarName("X"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("Z"), OutputGrad("ZOut")); + return std::unique_ptr(op); + } +}; + +class MultiOutGradShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("Y"), + ctx->GetInputDim(framework::GradVarName("YOut"))); + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + ctx->SetOutputDim(framework::GradVarName("Z"), + ctx->GetInputDim(framework::GradVarName("ZOut"))); + } +}; + +class MultiOutInplaceInToOut : public framework::InplaceInToOut { + public: + using framework::InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + return std::unordered_map{ + {"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"}, + }; + } +}; + +class MultiOutGradInplaceInToOut : public framework::InplaceInToOut { + public: + using framework::InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + return std::unordered_map{ + {framework::GradVarName("YOut"), framework::GradVarName("Y")}, + {framework::GradVarName("Out"), framework::GradVarName("X")}, + {framework::GradVarName("ZOut"), framework::GradVarName("Z")}, + }; + } +}; + +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; +REGISTER_OPERATOR(single_op, f::NOP, f::SingleOpMaker, f::SingleGradOpMaker, + f::SingleOpInplaceInToOut, f::SingleOpShapeInference); +REGISTER_OPERATOR(single_op_grad, f::NOP, f::SingleOpInplaceInToOut, + f::SingleGradOpShapeInference); +REGISTER_OPERATOR(multi_out_op, f::NOP, f::MultiOutOpMaker, f::MultiGradOpMaker, + f::MultiOutInplaceInToOut, f::MultiOutShapeInference); +REGISTER_OPERATOR(multi_out_grad, f::NOP, f::MultiOutGradInplaceInToOut, + f::MultiOutGradShapeInference); + +namespace paddle { +namespace framework { + +TEST(InferInplace, SingleOpInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("single_op"); + op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); + op->SetOutput("Out", {"test2_out"}); + + prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 64}); + prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_out"); + prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 1ul); + auto it = in_to_outs.begin(); + EXPECT_EQ(it->first, "test2_a"); + EXPECT_EQ(it->second, "test2_out"); +} + +TEST(InferInplace, SingleGradOpInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("single_op_grad"); + op->SetInput(GradVarName("Out"), {"test2_out"}); + op->SetOutput(GradVarName("X"), {"test2_a", "test2_b", "test2_c"}); + + prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_out"); + prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 1ul); + auto it = in_to_outs.begin(); + EXPECT_EQ(it->first, "test2_out"); + EXPECT_EQ(it->second, "test2_a"); +} + +TEST(InferInplace, MultiOutInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("multi_out_op"); + op->SetInput("X", {"a0", "a1"}); + op->SetInput("Y", {"b0"}); + op->SetInput("Z", {"c0", "c1"}); + op->SetOutput("Out", {"o0"}); + op->SetOutput("YOut", {"y0"}); + op->SetOutput("ZOut", {"z0"}); + + prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("o0"); + prog.MutableBlock(0)->Var("y0"); + prog.MutableBlock(0)->Var("z0"); + prog.MutableBlock(0)->Var("a0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("b0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("c0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("o0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("y0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("z0")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 3ul); + std::unordered_map expects = { + {"a0", "o0"}, {"b0", "y0"}, {"c0", "z0"}, + }; + EXPECT_TRUE(expects == in_to_outs); +} + +TEST(InferInplace, MultiGradInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("multi_out_grad"); + op->SetInput(GradVarName("Out"), {"o0"}); + op->SetInput(GradVarName("YOut"), {"y0"}); + op->SetInput(GradVarName("ZOut"), {"z0"}); + op->SetOutput(GradVarName("X"), {"a0", "a1"}); + op->SetOutput(GradVarName("Y"), {"b0"}); + op->SetOutput(GradVarName("Z"), {"c0", "c1"}); + + prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("o0"); + prog.MutableBlock(0)->Var("y0"); + prog.MutableBlock(0)->Var("z0"); + prog.MutableBlock(0)->Var("a0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("b0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("c0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("o0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("y0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("z0")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 3ul); + std::unordered_map expects = { + {"o0", "a0"}, {"y0", "b0"}, {"z0", "c0"}, + }; + EXPECT_TRUE(expects == in_to_outs); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 8de93cf285e4bf34c2d2bf425fa5f3459704b3d6..22d4c0a91cc1638264a8c57aa2841ff4e65a1400 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -52,16 +52,29 @@ bool HasCircleHelper( ir::Node *node, const std::map> &adj_list, std::unordered_set *visited, - std::unordered_set *in_trace) { + std::unordered_set *in_trace, + std::vector> *circles) { if (visited->find(node) == visited->end()) { visited->insert(node); in_trace->insert(node); for (ir::Node *in : adj_list.at(node)) { if (visited->find(in) == visited->end() && - HasCircleHelper(in, adj_list, visited, in_trace)) { + HasCircleHelper(in, adj_list, visited, in_trace, circles)) { return true; } else if (in_trace->find(in) != in_trace->end()) { + if (circles != nullptr) { + std::vector circle; + circle.emplace_back(in); + ir::Node *p = in; + for (auto &adj : adj_list.at(p)) { + if (in_trace->count(adj)) { + circle.emplace_back(adj); + p = adj; + } + } + circles->emplace_back(circle); + } return true; } } @@ -71,11 +84,12 @@ bool HasCircleHelper( } bool HasCircleInternal( - const std::map> &adj_list) { + const std::map> &adj_list, + std::vector> *circles) { std::unordered_set visited; std::unordered_set in_trace; for (auto &adj : adj_list) { - if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) { + if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) { return true; } } @@ -84,13 +98,18 @@ bool HasCircleInternal( } // namespace bool HasCircle(const Graph &graph) { - return HasCircleInternal(BuildOperationAdjList(graph)); + return HasCircleInternal(BuildOperationAdjList(graph), nullptr); +} + +bool FindCircleSubGraph(const Graph &graph, + std::vector> *circles) { + return HasCircleInternal(BuildOperationAdjList(graph), circles); } std::vector TopologySortOperations(const Graph &graph) { std::map> adj_list = BuildOperationAdjList(graph); - PADDLE_ENFORCE(!HasCircleInternal(adj_list)); + PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr)); std::unordered_set visited; std::vector ret; for (auto adj : adj_list) { diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index fba4936f2c5c971f6c63a452ec4480ff091db25c..214de9ec7d85aee6021b18866295777e317aa79d 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -28,6 +28,11 @@ namespace ir { // Test if the graph contains circle. bool HasCircle(const Graph &graph); +// Find All Circles for debugging, +// store all subgraph in circles. +bool FindCircleSubGraph(const Graph &graph, + std::vector> *circles); + size_t GraphNum(const Graph &graph); // Topology Sort the operations in the graph from inputs to outputs. diff --git a/paddle/fluid/framework/ir/graph_helper_test.cc b/paddle/fluid/framework/ir/graph_helper_test.cc index 260a73ae763bd2cdea9948e4d928377a7c718dda..d8973d5aeda1a2e0650a506b4c916b4346f01e2d 100644 --- a/paddle/fluid/framework/ir/graph_helper_test.cc +++ b/paddle/fluid/framework/ir/graph_helper_test.cc @@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) { // v4->outputs.push_back(o5); } +TEST(GraphHelperTest, Circles) { + ProgramDesc prog; + + Graph g(prog); + BuildCircleGraph(&g); + + std::vector> circles; + ASSERT_TRUE(FindCircleSubGraph(g, &circles)); + ASSERT_EQ(circles.size(), 1UL); +} + TEST(GraphHelperTest, GraphNum) { ProgramDesc prog; diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index 19e5c2c73eac74dee030a4f7820531800f737e4e..4b55bd0703eee399cd841f90ea0b18d8fbdc67e8 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -38,6 +38,7 @@ struct OpInfo { OpAttrChecker* checker_{nullptr}; InferVarTypeFN infer_var_type_; InferShapeFN infer_shape_; + InferInplaceOpFN infer_inplace_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 938e2024c3359c2acd65a1aa4af875a8350e4c58..d02c699b979d7693bd83fd43fc73f7e0aeddb0cc 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -57,5 +57,8 @@ using InferVarTypeFN = using InferShapeFN = std::function; +using InplacePair = std::unordered_map; +using InferInplaceOpFN = std::function; + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/utils/benchmark_tester.cc b/paddle/fluid/inference/utils/benchmark_tester.cc index 80763160df3adfd8c34e66bc7a5370808b349e76..0c48c2db9b691ae8cf587f2729c2789d4ce2dbe1 100644 --- a/paddle/fluid/inference/utils/benchmark_tester.cc +++ b/paddle/fluid/inference/utils/benchmark_tester.cc @@ -34,6 +34,6 @@ TEST(Benchmark, PersistToFile) { benchmark.SetLatency(220); benchmark.PersistToFile("1.log"); - benchmark.PersistToFile("1.log"); - benchmark.PersistToFile("1.log"); + benchmark.PersistToFile("2.log"); + benchmark.PersistToFile("3.log"); } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 7ec9d2fed53c9c73952db7dcdfc2d8e634f3f84e..189db2317d0544014d9c74e0fd5e9ead54925b9c 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -547,12 +547,14 @@ namespace ops = paddle::operators; __macro(Swish, swish); \ __macro(ThresholdedRelu, thresholded_relu); -#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ - REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ - ::paddle::operators::OP_NAME##OpMaker, \ - ::paddle::operators::ActivationOpInferVarType, \ - ::paddle::operators::OP_NAME##GradMaker); \ - REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) +#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ + REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ + ::paddle::operators::OP_NAME##OpMaker, \ + ::paddle::operators::ActivationOpInferVarType, \ + ::paddle::operators::OP_NAME##GradMaker, \ + ::paddle::framework::SingleOpInplaceInToOut); \ + REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \ + ::paddle::framework::SingleOpInplaceInToOut) #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 0736bd4d20eb60df4a1cb23aeec92dbe7f7495bd..feac4125381bd897dac89943af44850012e4761d 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -604,13 +604,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker { } }; +class BatchNormInplaceInToOut : public framework::InplaceInToOut { + public: + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}, + }; + return inplace_in_to_out; + } +}; + +class BatchNormGradInplaceInToOut : public framework::InplaceInToOut { + public: + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + // Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C] + {framework::GradVarName("Y"), framework::GradVarName("X")}, + {"SavedMean", framework::GradVarName("Scale")}, + {"SavedVariance", framework::GradVarName("Bias")}, + }; + return inplace_in_to_out; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, - ops::BatchNormOpInferVarType, ops::BatchNormGradMaker); -REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); + ops::BatchNormOpInferVarType, ops::BatchNormGradMaker, + ops::BatchNormInplaceInToOut); +REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp, + ops::BatchNormGradInplaceInToOut); REGISTER_OP_CPU_KERNEL( batch_norm, ops::BatchNormKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index 7e789cd8d9143164c2346b067855eb904e00075f..c6c658236c235f0a6767924026b0a7610071e918 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -18,6 +18,7 @@ namespace ops = paddle::operators; REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add); REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out", "X"); + REGISTER_OP_CPU_KERNEL( elementwise_add, ops::ElementwiseAddKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index fd2a98cb45f446e80a4be1b50e94ee611cd23e62..d04bb8f338a80946e8f1d945f66122f02f526eac 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel { } }; +class ElementwiseOpInplace : public framework::InplaceInToOut { + public: + using framework::InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + return std::unordered_map{ + {"X", "Out"}, + }; + } +}; + } // namespace operators } // namespace paddle @@ -299,6 +313,7 @@ class ElemwiseGradKernel : public framework::OpKernel { REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ __ElemwiseOp##op_type##Maker__, \ ::paddle::operators::ElementwiseOpInferVarType, \ - op_type##GradMaker); \ + op_type##GradMaker, \ + ::paddle::operators::ElementwiseOpInplace); \ REGISTER_OPERATOR(op_type##_grad, \ ::paddle::operators::ElementwiseOpExplicitGrad) diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 8e80dc0e641c443923076c31e269689b5bc134a7..bb904166c4a19997a57723d9f2e50cc839aae960 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -267,6 +267,35 @@ class Flatten2GradOp : public framework::OperatorBase { } }; +class FlattenOpInplaceInToOut : public framework::InplaceInToOut { + public: + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {"X", "Out"}, + }; + return inplace_in_to_out; + } +}; + +class FlattenGradInplaceinToOut : public framework::InplaceInToOut { + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {framework::GradVarName("Out"), framework::GradVarName("X")}, + }; + return inplace_in_to_out; + } +}; + } // namespace operators } // namespace paddle @@ -275,10 +304,13 @@ USE_OP(reshape); namespace ops = paddle::operators; REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker, ops::FlattenOpInferShape, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape); + paddle::framework::DefaultGradOpDescMaker, + ops::FlattenOpInplaceInToOut); +REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape, + ops::FlattenGradInplaceinToOut); REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker, - ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker); + ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker, + ops::FlattenOpInplaceInToOut); REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp, - ops::Flatten2GradInferShape); + ops::Flatten2GradInferShape, ops::FlattenGradInplaceinToOut); diff --git a/paddle/fluid/operators/norm_op.h b/paddle/fluid/operators/norm_op.h index 6c95d3f3bf3a3b0448a8f39915f8b025f7d3bd46..f81cbc2c733af2a42f27e2ecb05ee2f8e2f8c17b 100644 --- a/paddle/fluid/operators/norm_op.h +++ b/paddle/fluid/operators/norm_op.h @@ -99,10 +99,10 @@ class NormGradKernel : public framework::OpKernel { auto dx_e = framework::EigenVector::Flatten(*out_dx); Eigen::DSizes shape(pre, n, post); - Eigen::DSizes norm_shape(pre, post); + Eigen::DSizes rshape(pre, 1, post); auto x = x_e.reshape(shape); auto dy = dy_e.reshape(shape); - auto norm = norm_e.reshape(norm_shape); + auto norm = norm_e.reshape(rshape); auto dx = dx_e.reshape(shape); framework::Tensor rsum; @@ -111,7 +111,6 @@ class NormGradKernel : public framework::OpKernel { Eigen::DSizes rdim(1); Eigen::DSizes bcast(1, n, 1); - Eigen::DSizes rshape(pre, 1, post); // dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)] // = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x)) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 32365d6a9602fa8ad2c01b59c8cd361d52ed973f..eda54f76b898cdf893347d31cadb86dea892a4ce 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -327,14 +327,45 @@ class Reshape2GradOp : public framework::OperatorWithKernel { } }; +class ReshapeOpInplaceInToOut : public framework::InplaceInToOut { + public: + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {"X", "Out"}, + }; + return inplace_in_to_out; + } +}; + +class ReshapeGradInplaceInToOut : public framework::InplaceInToOut { + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {framework::GradVarName("Out"), framework::GradVarName("X")}, + }; + return inplace_in_to_out; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); + paddle::framework::DefaultGradOpDescMaker, + ops::ReshapeOpInplaceInToOut); +REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp, + ops::ReshapeGradInplaceInToOut); REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel); @@ -344,8 +375,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ops::ReshapeGradKernel); REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, - ops::Reshape2GradMaker); -REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp); + ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut); +REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, + ops::ReshapeGradInplaceInToOut); REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel); diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 981969d2aaa684731a615ec64ca7f7718b35cf09..4ea77ed30db212b694f2050952655dd1a42215bd 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { } }; +using ScaleOpInplace = framework::SingleOpInplaceInToOut; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker, - ops::ScaleOpVarTypeInference); + ops::ScaleOpVarTypeInference, ops::ScaleOpInplace); REGISTER_OP_CPU_KERNEL( scale, ops::ScaleKernel, ops::ScaleKernel, diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index bc889a5a042a27838ba6ba0fccb187ec11b5f0c5..8fbf299a7c056aff3bfd4cbd3e3cc28fd3c6ccf2 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker { return std::unique_ptr(op); } }; + +class SoftmaxInplaceInToOut : public framework::InplaceInToOut { + public: + using framework::InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + return std::unordered_map{ + {"X", "Out"}, + }; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 4dcec21952f2eba72574c95303ba728df8746401..6549229e05de5f2a809b56775d9788bbf8e5c1ae 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1096,6 +1096,10 @@ All parameter, weight, gradient are variables in Paddle. "memory_early_delete", [](const BuildStrategy &self) { return self.memory_early_delete_; }, [](BuildStrategy &self, bool b) { self.memory_early_delete_ = b; }) + .def_property( + "enable_inplace", + [](const BuildStrategy &self) { return self.enable_inplace_; }, + [](BuildStrategy &self, bool b) { self.enable_inplace_ = b; }) .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(true); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 564882bd2a23437665777c646e6e399cdffae723..396f36e188b27fe450cc19b3b8ccf967daf1456c 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -158,7 +158,8 @@ def __bootstrap__(): 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', 'sync_nccl_allreduce', 'limit_of_tmp_allocation', - 'times_excess_than_required_tmp_allocation' + 'times_excess_than_required_tmp_allocation', + 'enable_inplace_whitelist' ] core.init_gflags([sys.argv[0]] + diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index a35a4c59835e2a64a11ae156bed34d4b35696f73..ef0242942838fcca737a10fafbafa61bf520b532 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -174,6 +174,11 @@ class CompiledProgram(object): self._exec_strategy.num_threads = cpu_num * 2 trainers_endpoints = self._program._trainers_endpoints + + # FIXME(dzhwinter): enable_inplace should be after memory_optimize + # if turn on python memory optimize, turn off the inplace_pass. + self._build_strategy.enable_inplace = False if self._program._is_mem_optimized else True + if self._build_strategy.num_trainers > 1 and trainers_endpoints: assert self._build_strategy.num_trainers == len( trainers_endpoints), "num_trainers == len(end_points)" diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 96587b6e904f681a71182ffdb03608b5edde5e46..c0b0ad8a202b82183de9ec1edd43cb10db10fb5c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1725,6 +1725,19 @@ class Program(object): self._trainers_endpoints = [] # the distributed lookup table names self._distributed_lookup_table = None + # @deprecated(the python memory optimize transpiler is deprecated) + # whether the program is optimized by memory_optimize_transpiler + self.__is_mem_optimized = False + + @property + def _is_mem_optimized(self): + # if the program is optimized, operator input/outputs + # maybe same, which conflict with save_inference_model. + return self.__is_mem_optimized + + @_is_mem_optimized.setter + def _is_mem_optimized(self, target): + self.__is_mem_optimized = target @property def op_role(self): @@ -1744,7 +1757,7 @@ class Program(object): return self._current_role @op_role.setter - def set_op_role(self, role): + def op_role(self, role): self._current_role = role @property diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 95cc05ac7191783969ff0fbf286c17bd1cfd6c7d..a2abbf36c0267d85c9c97af00c9faabf1187822c 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -16,6 +16,7 @@ from __future__ import print_function import os import errno +import warnings import time import shutil import six @@ -931,6 +932,13 @@ def save_inference_model(dirname, if main_program is None: main_program = default_main_program() + if main_program._is_mem_optimized: + warnings.warn( + "save_inference_model must put before you call memory_optimize. \ + the memory_optimize will modify the original program, \ + is not suitable for saving inference model \ + we save the original program as inference model.", + RuntimeWarning) # fix the bug that the activation op's output as target will be pruned. # will affect the inference performance. diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index a07ff6ac69ca20c8c68659a67606076ce8cdf027..52b260efd15066a114a8146106685043654c91ea 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -146,6 +146,9 @@ class ParallelExecutor(object): # step4: get main_program, scope, local_scopes main = main_program if main_program \ else framework.default_main_program() + # FIXME(dzhwinter): enable_inplace should be after memory_optimize + # if turn on python memory optimize, turn off the inplace_pass. + build_strategy.enable_inplace = False if main._is_mem_optimized else True scope = scope if scope is not None else executor.global_scope() if share_vars_from and not isinstance(share_vars_from, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 699181d01da862dca72113e6c11630ae5693e41c..4b26bacce968a6da72e9aa043adb38918b293a35 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -110,6 +110,10 @@ py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executo if(NOT APPLE) py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL) endif() +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + # change the timeout from 600 to 900, because in debug mode, this test need more time. + set_tests_properties(test_image_classification_resnet PROPERTIES TIMEOUT 900) +endif() if (WITH_NGRAPH) add_subdirectory(ngraph) diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index fdacd241f9e1f8d442f55098e2d192a3d57fdaf1..c429c8af7d37cb4e209edc41f704868afe054829 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase): seed=None, use_parallel_executor=True, use_reduce=False, - use_ir_memory_optimize=False, + use_ir_memory_optimize=True, + enable_inplace=True, fuse_elewise_add_act_ops=False, fuse_relu_depthwise_conv=False, optimizer=fluid.optimizer.Adam, @@ -60,63 +61,65 @@ class TestParallelExecutorBase(unittest.TestCase): main.random_seed = seed loss = method(use_feed=feed_dict is not None) - if optimizer: optimizer().minimize(loss) if memory_opt: fluid.memory_optimize(main) - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(startup) - exec_strategy = fluid.ExecutionStrategy() - exec_strategy.allow_op_delay = allow_op_delay - if use_fast_executor: - exec_strategy.use_experimental_executor = True - build_strategy = fluid.BuildStrategy() - build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ - if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce - build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops - build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv - build_strategy.memory_optimize = use_ir_memory_optimize - build_strategy.enable_sequential_execution = enable_sequential_execution - if use_cuda and core.is_compiled_with_cuda(): - build_strategy.remove_unnecessary_lock = True - if use_parallel_executor: - binary = compiler.CompiledProgram(main).with_data_parallel( - loss_name=loss.name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - else: - binary = compiler.CompiledProgram(main) - - if batch_size is not None: - batch_size *= fluid.core.get_cuda_device_count( - ) if use_cuda else int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - begin = time.time() - first_loss, = run_executor( - exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) - - for i in range(iter): - run_executor( - exe=exe, binary=binary, feed=feed_dict, fetch_list=[]) - - last_loss, = run_executor( - exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) - end = time.time() - - if batch_size is not None: - print("%.4f Instance per second" % ( - (batch_size * iter + 2) / (end - begin))) - - avg_last_loss_val = np.array(last_loss).mean() - avg_first_loss_val = np.array(first_loss).mean() - if math.isnan(float(avg_last_loss_val)) or math.isnan( - float(avg_first_loss_val)): - sys.exit("got NaN loss, training failed.") - - print(first_loss, last_loss) - # self.assertGreater(first_loss[0], last_loss[0]) - return first_loss, last_loss + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.allow_op_delay = allow_op_delay + if use_fast_executor: + exec_strategy.use_experimental_executor = True + build_strategy = fluid.BuildStrategy() + build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ + if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce + build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops + build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv + build_strategy.memory_optimize = use_ir_memory_optimize + # python memory optimization is conflict with inplace pass. + # Use ir graph memory optimization after inplace pass is the correct way. + build_strategy.enable_inplace = False if memory_opt else enable_inplace + build_strategy.enable_sequential_execution = enable_sequential_execution + + if use_cuda and core.is_compiled_with_cuda(): + build_strategy.remove_unnecessary_lock = True + if use_parallel_executor: + binary = compiler.CompiledProgram(main).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + else: + binary = compiler.CompiledProgram(main) + + if batch_size is not None: + batch_size *= fluid.core.get_cuda_device_count( + ) if use_cuda else int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + begin = time.time() + first_loss, = run_executor( + exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) + + for i in range(iter): + run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[]) + + last_loss, = run_executor( + exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) + end = time.time() + + if batch_size is not None: + print("%.4f Instance per second" % ( + (batch_size * iter + 2) / (end - begin))) + + avg_last_loss_val = np.array(last_loss).mean() + avg_first_loss_val = np.array(first_loss).mean() + if math.isnan(float(avg_last_loss_val)) or math.isnan( + float(avg_first_loss_val)): + sys.exit("got NaN loss, training failed.") + + print(first_loss, last_loss) + # self.assertGreater(first_loss[0], last_loss[0]) + return first_loss, last_loss diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py index 754d5fd40953311a5deb466fa42216f72671a65a..603c8e74885d2a050e6e1e3101dce880b6eabe9c 100644 --- a/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py @@ -16,12 +16,10 @@ import os import unittest os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" -from test_parallel_executor_transformer import TestTransformer - - -class EagerDeletionTestTransformer(TestTransformer): - pass +os.environ[ + 'RECORDIO_FILENAME'] = '/tmp/eager_deletion_transformer.wmt16.recordio' +from test_parallel_executor_transformer import TestTransformer if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_inference_model_io.py b/python/paddle/fluid/tests/unittests/test_inference_model_io.py index 3b54827dd2e5ba177cb1a91019581c3fb6f63bb5..9c9f86330704466c7a8801af6ab0fb2bba23f931 100644 --- a/python/paddle/fluid/tests/unittests/test_inference_model_io.py +++ b/python/paddle/fluid/tests/unittests/test_inference_model_io.py @@ -25,6 +25,7 @@ import paddle.fluid.layers as layers import paddle.fluid.optimizer as optimizer from paddle.fluid.framework import Program, program_guard from paddle.fluid.io import save_inference_model, load_inference_model +from paddle.fluid.transpiler import memory_optimize class TestBook(unittest.TestCase): @@ -87,5 +88,31 @@ class TestBook(unittest.TestCase): self.assertEqual(expected, actual) +class TestSaveInferenceModel(unittest.TestCase): + def test_save_inference_model(self): + MODEL_DIR = "./tmp/inference_model2" + init_program = Program() + program = Program() + + # fake program without feed/fetch + with program_guard(program, init_program): + x = layers.data(name='x', shape=[2], dtype='float32') + y = layers.data(name='y', shape=[1], dtype='float32') + + y_predict = layers.fc(input=x, size=1, act=None) + + cost = layers.square_error_cost(input=y_predict, label=y) + avg_cost = layers.mean(cost) + + place = core.CPUPlace() + exe = executor.Executor(place) + exe.run(init_program, feed={}, fetch_list=[]) + + memory_optimize(program, print_log=True) + self.assertEqual(program._is_mem_optimized, True) + # will print warning message + save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..4e196758efc990506957089fb5b88ebb099cca29 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py @@ -0,0 +1,76 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import unittest +import numpy as np +import paddle.fluid.core as core +import paddle.fluid as fluid +from parallel_executor_test_base import TestParallelExecutorBase + + +def fc_with_batchnorm(use_feed): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + hidden = img + for _ in range(3): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + hidden = fluid.layers.batch_norm(input=hidden) + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestIrInplace(TestParallelExecutorBase): + @classmethod + def setUpClass(cls): + os.environ['CPU_NUM'] = str(4) + + def _fc_with_batchnorm(self, + ir_memory_optimize, + enable_inplace, + memory_opt=False): + + if not core.is_compiled_with_cuda(): + return + np.random.seed(5) + img = np.random.random(size=[32, 784]).astype(np.float32) + label = np.ones(shape=[32, 1], dtype='int64') + self.check_network_convergence( + fc_with_batchnorm, + feed_dict={"image": img, + "label": label}, + use_cuda=True, + memory_opt=memory_opt, + use_ir_memory_optimize=ir_memory_optimize, + enable_inplace=enable_inplace) + + def test_fc_with_batchnorm(self, delta=1e-3): + loss00 = self._fc_with_batchnorm(False, False) + loss10 = self._fc_with_batchnorm(True, False) + loss01 = self._fc_with_batchnorm(False, True) + loss11 = self._fc_with_batchnorm(True, True) + self.assertAlmostEqual(loss00, loss10, delta=delta) + self.assertAlmostEqual(loss00, loss01, delta=delta) + self.assertAlmostEqual(loss00, loss11, delta=delta) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py index e7a56bb6386a812e43e5c1b5c08cd0682aa9223a..9548598d75367ed1f1863d1f6ae50b83d58f8c7f 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py @@ -200,7 +200,7 @@ class TestResnet(TestParallelExecutorBase): model, use_cuda, iter=20, - delta2=1e-6): + delta2=1e-5): if use_cuda and not core.is_compiled_with_cuda(): return @@ -228,7 +228,7 @@ class TestResnet(TestParallelExecutorBase): optimizer=optimizer) for loss in zip(all_reduce_first_loss, reduce_first_loss): - self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) + self.assertAlmostEquals(loss[0], loss[1], delta=1e-5) for loss in zip(all_reduce_last_loss, reduce_last_loss): self.assertAlmostEquals(loss[0], loss[1], delta=delta2) @@ -258,17 +258,17 @@ class TestResnet(TestParallelExecutorBase): enable_sequential_execution=True) for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq): - self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) + self.assertAlmostEquals(loss[0], loss[1], delta=1e-5) for loss in zip(all_reduce_last_loss, all_reduce_last_loss_seq): self.assertAlmostEquals(loss[0], loss[1], delta=delta2) for loss in zip(reduce_first_loss, reduce_first_loss_seq): - self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) + self.assertAlmostEquals(loss[0], loss[1], delta=1e-5) for loss in zip(reduce_last_loss, reduce_last_loss_seq): self.assertAlmostEquals(loss[0], loss[1], delta=delta2) for loss in zip(all_reduce_first_loss_seq, reduce_first_loss_seq): - self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) + self.assertAlmostEquals(loss[0], loss[1], delta=1e-5) for loss in zip(all_reduce_last_loss_seq, reduce_last_loss_seq): self.assertAlmostEquals(loss[0], loss[1], delta=delta2) @@ -277,7 +277,7 @@ class TestResnet(TestParallelExecutorBase): use_cuda=True, use_reduce=False, iter=20, - delta2=1e-6): + delta2=1e-5): if use_cuda and not core.is_compiled_with_cuda(): return @@ -308,7 +308,7 @@ class TestResnet(TestParallelExecutorBase): optimizer=optimizer) self.assertAlmostEquals( - np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6) + np.mean(parallel_first_loss), single_first_loss[0], delta=1e-5) self.assertAlmostEquals( np.mean(parallel_last_loss), single_last_loss[0], delta=delta2) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py index 3827743908c1d76931572277323d1dd5ddd05523..aacc1c3ecda8c25dec9f08827a856d38c37b1b2f 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py @@ -24,7 +24,7 @@ import paddle.fluid.core as core import paddle.dataset.wmt16 as wmt16 import os -WMT16_RECORDIO_FILE = "/tmp/wmt16.recordio" +WMT16_RECORDIO_FILE = os.environ.get('RECORDIO_FILENAME', '/tmp/wmt16.recordio') class ModelHyperParams(object): diff --git a/python/paddle/fluid/tests/unittests/transformer_model.py b/python/paddle/fluid/tests/unittests/transformer_model.py index 143d187edc3a154418f9e639b7d492c8ce994d42..905b7d6fe75ab0080e3e97fbd4710ad913a05a38 100644 --- a/python/paddle/fluid/tests/unittests/transformer_model.py +++ b/python/paddle/fluid/tests/unittests/transformer_model.py @@ -17,6 +17,7 @@ from __future__ import print_function from functools import partial import numpy as np +import os import paddle.fluid as fluid import paddle.fluid.layers as layers from paddle.fluid.layers.io import open_recordio_file @@ -408,7 +409,7 @@ def transformer( trg_pad_idx, pos_pad_idx, ): file_obj = open_recordio_file( - filename='/tmp/wmt16.recordio', + filename=os.environ.get('RECORDIO_FILENAME', '/tmp/wmt16.recordio'), shapes=[ [batch_size * max_length, 1], [batch_size * max_length, 1], diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index e5d48d3d19ed71624d528144f13e23770a09362a..52c1aea288fa2bb7478ad14186367900c05f64e7 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -540,6 +540,7 @@ def memory_optimize(input_program, if skip_opt_set is not None: skip_opt_set = set(map(to_name_str, skip_opt_set)) cfgs = _get_cfgs(input_program) + input_program._is_mem_optimized = True for cfg in cfgs: cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) @@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None): None """ cfgs = _get_cfgs(input_program) + input_program._is_mem_optimized = True for cfg in cfgs: cfg.release_memory(skip_opt_set=skip_opt_set)