未验证 提交 74bc55c2 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14975 from dzhwinter/ir_inplace_pass

Ir inplace pass
...@@ -128,7 +128,7 @@ cc_test(version_test SRCS version_test.cc DEPS version) ...@@ -128,7 +128,7 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
...@@ -192,6 +192,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto) ...@@ -192,6 +192,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc) proto_desc)
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper)
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
......
...@@ -50,7 +50,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -50,7 +50,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass) cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper)
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
...@@ -65,12 +67,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he ...@@ -65,12 +67,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass inplace_op_pass)
if (WITH_GPU) if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif() endif()
cc_test(memory_reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types.cc DEPS framework_proto graph) cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
cc_test(analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types.cc DEPS framework_proto graph graph_helper op_registry pass) cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <memory> #include <memory>
#include "paddle/fluid/framework/details/memory_reuse_types.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
...@@ -47,6 +47,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -47,6 +47,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("sequential_execution_pass"); AppendPass("sequential_execution_pass");
} }
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
AppendPass("fuse_relu_depthwise_conv_pass");
}
// NOTE(dzhwinter): A note for automatical inplace.
// 1. modify program desc passes should put
// before inplace pass.
// 2. manually configured inplace should put
// before inplace_pass
// Add automatically inplace.
if (strategy_.enable_inplace_) {
AppendPass("inplace_pass");
}
// Add a graph viz pass to record a graph. // Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass"); auto viz_pass = AppendPass("graph_viz_pass");
...@@ -55,10 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -55,10 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path)); viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
} }
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
AppendPass("fuse_relu_depthwise_conv_pass");
}
if (strategy.fuse_elewise_add_act_ops_) { if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass"); auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
// Add a graph viz pass to record a graph. // Add a graph viz pass to record a graph.
...@@ -88,7 +100,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -88,7 +100,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars // A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface. // , so fetchlist should be set persistable before call the Run interface.
if (strategy.memory_optimize_) { if (strategy.memory_optimize_) {
auto analysis_var_pass = AppendPass("analysis_var_pass"); auto memory_optimize_pass = AppendPass("memory_optimize_pass");
} }
AppendMultiDevPass(strategy); AppendMultiDevPass(strategy);
...@@ -186,8 +198,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -186,8 +198,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("nccl_ctxs"); pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx); pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif #endif
} else if (pass->Type() == "memory_optimize_pass") {
} else if (pass->Type() == "analysis_var_pass") { if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
const std::vector<OpDesc *> *all_op_descs = const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps()); new std::vector<OpDesc *>(main_program.Block(0).AllOps());
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs, graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
...@@ -214,6 +228,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -214,6 +228,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Set<const std::vector<OpDesc *>>( pass->Set<const std::vector<OpDesc *>>(
kAllOpDescs, kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps())); new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "inplace_pass") {
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
graph->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") { } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
if (!use_cuda) { if (!use_cuda) {
LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on " LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
...@@ -239,9 +260,10 @@ USE_PASS(allreduce_mode_multi_devices_pass); ...@@ -239,9 +260,10 @@ USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass); USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
USE_PASS(analysis_var_pass); USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass); USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass); USE_PASS(lock_free_optimize_pass);
USE_PASS(graph_to_program_pass); USE_PASS(graph_to_program_pass);
...@@ -80,6 +80,11 @@ struct BuildStrategy { ...@@ -80,6 +80,11 @@ struct BuildStrategy {
bool memory_early_delete_{false}; bool memory_early_delete_{false};
// TODO(dzhwinter):
// make enable_inplace, memory_optimize_
// memory_early_delete_ true by default
bool enable_inplace_{false};
bool enable_sequential_execution_{false}; bool enable_sequential_execution_{false};
bool fuse_broadcast_op_{false}; bool fuse_broadcast_op_{false};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SplitOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "");
AddOutput("Out", "").AsDuplicable();
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include <algorithm>
#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means one op output variable reuse the input space.
// By our design, one operator only can read its input(const Variable),
// write its output(non-const Variable). If one operator is inplaced, means
// user have chance to write the space before reading happens.
// Especially when some optimize code writing style is applied.
//
//
// /* wrong case in operator */
// /*In this case, a larger allocation is allocated, input content is lost*/
// const Tensor* in = ctx.Input<Tensor>("In")
// Tensor* out = ctx.Output<Tensor>("Out");
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = 0; // input contect is overwrited.
// NOTE(dzhwinter):
// Only for backward compacity and stable. if enable_inplace_whitelist is turn
// on.
// only the ops in whitelist will be use inplace strategy.
// if not, all the op will be inplaced if it registered with InplaceClass
DEFINE_bool(
enable_inplace_whitelist, false,
"If this option turns on, only these op in whitelist can be inplaced."
"If it turns off, all of the running op can be candidate of inplaced op."
"Such as scale, elementwise_add"
"By default, it's turned on");
DECLARE_string(memory_optimize_debug);
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
"sigmoid",
"exp",
"relu",
"tanh",
"sqrt",
"ceil",
"floor",
"reciprocal",
"relu6",
"soft_relu",
"hard_sigmoid",
"batch_norm",
"batch_norm_grad",
"sum",
"sum_grad",
"scale",
"reshape",
"elementwise_add",
"elementwise_add_grad",
};
// clang-format on
namespace paddle {
namespace framework {
namespace details {
static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
// if next op is inplaced, then return the output var
// otherwise return nullptr
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
ir::Node* inplaced_var = nullptr;
for (auto* next_op : var->outputs) {
for (auto* output : next_op->outputs) {
if (output->IsVar() && !output->IsCtrlVar() &&
output->Name() == var->Name()) {
inplaced_var = output;
}
}
}
return inplaced_var;
}
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
if (var->inputs.empty()) return nullptr;
auto* prev_op = var->inputs.at(0);
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
[&](ir::Node* node) {
if (node->IsVar() && !node->IsCtrlVar() &&
node->Name() == var->Name()) {
return true;
} else {
return false;
}
});
return input_it == prev_op->inputs.end() ? nullptr : *input_it;
}
InplacePass::InplacePass() : Pass() {
if (FLAGS_enable_inplace_whitelist) {
for (auto& s : kInplacedOpWhiteList) {
whitelist_.emplace(s);
}
}
}
void InplacePass::InitSSAGraphNodes() const {
std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars;
for (auto* op : view_.AllOps()) {
for (auto* node : op->inputs) {
if (!node->IsVar() || node->IsCtrlVar()) continue;
if (all_vars[node->Name()].count(node) == 0) {
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
}
}
for (auto* node : op->outputs) {
if (!node->IsVar() || node->IsCtrlVar()) continue;
if (all_vars[node->Name()].count(node) == 0) {
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
}
}
}
}
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
var_nodes_.clear();
view_.Build(graph.get());
InitSSAGraphNodes();
for (auto* op : view_.AllOps()) {
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
TryInplaceOpInputOutput(op, graph.get());
}
graph->ResolveHazard(var_nodes_);
return graph;
}
void InplacePass::InplaceModifyDesc(const std::string& var,
const std::string& cache_var,
const size_t& idx) const {
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
ir::Node* op = view_.AllOps()[i];
PADDLE_ENFORCE(op->IsOp() && op->Op());
auto* op_desc = op->Op();
op_desc->RenameInput(var, cache_var);
op_desc->RenameOutput(var, cache_var);
if (op_desc->Block()->HasVar(var)) op_desc->Block()->RemoveVar(var);
op_desc->Flush();
}
}
const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx,
ir::Graph* graph) const {
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
var_nodes_[var].at(0)->Var() != nullptr);
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
var_desc->SetName(cache_var);
SSANodePair swap_nodes;
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
auto* op = view_.AllOps()[i];
// redirect the input to the latest version of cache_var
for (auto* node : op->inputs) {
if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
// swap node to cache_node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
cache_node);
cache_node->inputs.emplace_back(prev_op);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
swap_nodes.emplace_back(std::make_pair(node, cache_node));
}
}
// if we need to rename the output,
// always create a newer version of cache_var
for (auto* node : op->outputs) {
if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
// swap node to cache node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
cache_node->inputs.emplace_back(op);
std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
swap_nodes.emplace_back(std::make_pair(node, cache_node));
}
}
}
return swap_nodes;
}
void InplacePass::CommitModify(const SSANodePair& swap_nodes,
ir::Graph* graph) const {
for (auto& pair : swap_nodes) {
auto *node = pair.first, *cache_node = pair.second;
const std::string var = node->Name(), cache_var = cache_node->Name();
var_nodes_[cache_var].emplace_back(cache_node);
graph->RemoveNode(node);
auto& nodes = var_nodes_.at(var);
// release unused var in graph. Because python side memory optimize
// may reused the var in same name, so we only clear the var node
// after current inplaced index.
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
}
}
void InplacePass::WithdrawModify(const SSANodePair& nodes,
ir::Graph* graph) const {
for (auto& pair : nodes) {
auto *node = pair.first, *cache_node = pair.second;
const std::string var = node->Name(), cache_var = cache_node->Name();
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
node);
}
graph->RemoveNode(cache_node);
}
}
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
ir::Graph* graph) const {
VLOG(4) << "Try to inplace op " << op->Name();
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
"op_desc is nullptr");
// some pre-requirments need to meet if the op want to inplaced.
auto* op_desc = op->Op();
auto& infer_inplace =
OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_;
// 1. infer_inplace_ is registered.
if (!static_cast<bool>(infer_inplace)) return;
PADDLE_ENFORCE(static_cast<bool>(infer_inplace),
"%s's infer_inplace has not been registered", op_desc->Type());
auto* block = op_desc->Block();
auto in_to_outs = infer_inplace(*op_desc, block);
auto& all_ops = view_.AllOps();
auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
size_t idx = std::distance(all_ops.begin(), cursor);
for (auto& pair : in_to_outs) {
auto& in_var_name = pair.first;
auto& out_var_name = pair.second;
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
// 2. there is no external pending op on the input node
if (view_.PendingOpsOnVar(in_node).size() > 1) {
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input has external dependency."
"inplace such pair will overwrite the memory.",
out_var_name, in_var_name, op->Name());
continue;
}
// 3. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future.
if (view_.InSkipSet(out_node->Name())) {
VLOG(4) << string::Sprintf(
"Skiped %s => %s reused previous memory block in python memory "
"optmize,"
"it inplace may generate a circle",
out_var_name, in_var_name, op->Name());
continue;
}
// Debug Interface. Which would be skipped by the pass.
if (out_node->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
<< out_node->Name();
continue;
}
// NOTE(dzhwinter):
// two stage commit of inplaced process. if after inplace happens generate a
// circle,
// then withdraw the changes. Otherwise, safely add the node.
auto swap_nodes =
TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
if (!ir::HasCircle(*graph)) {
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
out_var_name, in_var_name);
InplaceModifyDesc(out_var_name, in_var_name, idx);
CommitModify(swap_nodes, graph);
} else {
VLOG(3) << string::Sprintf(
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
out_var_name, in_var_name, op->Name());
WithdrawModify(swap_nodes, graph);
}
}
}
ir::Node* GraphView::GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const {
// nodes should be op->inputs/outputs
// node in same node do have different name.
std::unordered_set<std::string> nodes_in_op;
bool has_dup_node =
std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) {
if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) {
if (nodes_in_op.count(node->Name())) return true;
nodes_in_op.emplace(node->Name());
}
return false;
});
PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!");
ir::Node* node = nullptr;
for (auto* it : nodes) {
if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue;
if (it->Name() == name) {
node = it;
break;
}
}
PADDLE_ENFORCE(node != nullptr,
string::Sprintf("Not found var %s in nodes!", name));
return node;
}
std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
// get the pending ops depends on same var node.
// because node also maybe a inplaced variable, so need to backtrack all the
// previous inplaced vars.
std::vector<ir::Node*> pending_ops;
ir::Node* p = node;
while (p != nullptr) {
pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end());
p = GetPrevCascadeInplacedVar(p);
}
return pending_ops;
}
void GraphView::Build(ir::Graph* g) {
// track the var nodes in correct order.
// Because we insert some new created node. Which may have data race between
// nodes.
// resolve data harzards depends on the var nodes in right order.
ops_ = SortOpLikeDescOrder(*g);
// 1. track the nodes which reused previous node in Python memory optimize.
// these node can not be inplaced, otherwise may generate a circle in graph.
std::unordered_set<std::string> all_vars;
for (auto& node : g->Nodes()) {
if (node->IsVar()) continue;
for (auto& out : node->outputs) {
if (out->IsCtrlVar() || out->Var() == nullptr) continue;
if (all_vars.count(out->Name())) {
dup_nodes_.emplace(out->Name());
} else {
all_vars.emplace(out->Name());
}
}
}
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue;
if (node->Name() == "send") {
for (auto& in : node->inputs) {
dup_nodes_.emplace(in->Name());
}
}
if (node->Name() == "recv") {
for (auto& out : node->outputs) {
dup_nodes_.emplace(out->Name());
}
}
}
}
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
bool GraphView::InSkipSet(const std::string& var) const {
return dup_nodes_.count(var);
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may abtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class GraphView {
public:
GraphView() = default;
void Build(ir::Graph* g);
const std::vector<ir::Node*>& AllOps();
ir::Node* GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const;
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
// Will Deperated in the future.
// NOTE(dzhwinter) :
// 1. Python memory optimize will reuse
// memory based var name, so different op output may
// have the same variable name. enable inplace on such node
// will generate a circle in ssa graph.
// 2. DistributeTranspiler will use unique name to
// map the parameter and gradient, must be skipped.
bool InSkipSet(const std::string& var) const;
private:
std::vector<ir::Node*> ops_;
std::unordered_set<std::string> dup_nodes_; // mem opt affect nodes
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
};
typedef std::vector<std::pair<ir::Node*, ir::Node*>> SSANodePair;
class InplacePass : public ir::Pass {
public:
InplacePass();
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void InitSSAGraphNodes() const;
private:
const SSANodePair TryInplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx,
ir::Graph* graph) const;
void CommitModify(const SSANodePair&, ir::Graph* graph) const;
void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const;
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
const size_t& idx) const;
void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
mutable std::unordered_set<std::string> whitelist_;
mutable GraphView view_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <queue> #include <queue>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/memory_reuse_types.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
......
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/memory_reuse_types.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <functional>
#include <iostream> #include <iostream>
#include <numeric>
#include <sstream> #include <sstream>
#include <string> #include <string>
...@@ -21,15 +23,17 @@ namespace paddle { ...@@ -21,15 +23,17 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
size_t NodeSizeInBytes(const VarDesc& node) {
auto shape = node.GetShape();
int size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
size_t type_size = SizeOfType(node.GetDataType());
return type_size * std::abs(size);
}
size_t NodeSizeInBytes(ir::Node* n) { size_t NodeSizeInBytes(ir::Node* n) {
auto* desc = FindVarDescInBlock(n); auto* desc = FindVarDescInBlock(n);
auto shape = desc->GetShape(); return NodeSizeInBytes(*desc);
size_t type_size = SizeOfType(desc->GetDataType());
int size = 1;
for (auto& s : shape) {
size *= s;
}
return type_size * std::abs(size);
} }
std::string DebugStringImpl(VarDesc* var) { std::string DebugStringImpl(VarDesc* var) {
...@@ -83,7 +87,7 @@ struct NodeComparator { ...@@ -83,7 +87,7 @@ struct NodeComparator {
} }
}; };
void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) { void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar()); PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
PADDLE_ENFORCE(op->IsOp()); PADDLE_ENFORCE(op->IsOp());
if (mark_table_.count(var->Name()) != 0) { if (mark_table_.count(var->Name()) != 0) {
...@@ -119,11 +123,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) { ...@@ -119,11 +123,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
mark_table_[var->Name()] = it; mark_table_[var->Name()] = it;
} }
int OrderedNodePairPool::GetIndex(ir::Node* var) { int OrderedNodeList::GetIndex(ir::Node* var) {
return std::distance(nodes_.begin(), mark_table_[var->Name()]); return std::distance(nodes_.begin(), mark_table_[var->Name()]);
} }
ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const { ir::Node* OrderedNodeList::NodeMatch(ir::Node* var) const {
ir::Node* found_node = nullptr; ir::Node* found_node = nullptr;
NodeComparator compare_node; NodeComparator compare_node;
...@@ -136,13 +140,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const { ...@@ -136,13 +140,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
return found_node; return found_node;
} }
void OrderedNodePairPool::Erase(ir::Node* var) { void OrderedNodeList::Erase(ir::Node* var) { Erase(var->Name()); }
PADDLE_ENFORCE(mark_table_.count(var->Name()));
nodes_.erase(mark_table_[var->Name()]); void OrderedNodeList::Erase(const std::string& var) {
mark_table_.erase(var->Name()); PADDLE_ENFORCE(mark_table_.count(var));
nodes_.erase(mark_table_[var]);
mark_table_.erase(var);
} }
std::string OrderedNodePairPool::ToString() const { std::string OrderedNodeList::ToString() const {
std::stringstream ss; std::stringstream ss;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) { for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
ss << DebugString(it->first) << " "; ss << DebugString(it->first) << " ";
...@@ -150,6 +156,43 @@ std::string OrderedNodePairPool::ToString() const { ...@@ -150,6 +156,43 @@ std::string OrderedNodePairPool::ToString() const {
return ss.str(); return ss.str();
} }
bool NodeCanReused(ir::Node* node) {
if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false;
// auto* desc = node->Var();
bool flag = NodeCanReused(*node->Var());
for (auto* op : node->inputs) {
if (op->Op()->HasAttr("force_cpu")) {
// op output force generated in cpu, can not be reused.
flag &= framework::AttrReader(op->Op()->GetAttrMap())
.Get<bool>("force_cpu") == 0;
}
}
return flag;
}
bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType();
if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
node.GetShape().empty()) {
return false;
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
std::string name = node.Name();
if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
return false;
return true;
}
bool OpHasSubBlock(OpDesc* desc) {
const AttributeMap& attrs = desc->GetAttrMap();
for (auto& attr : attrs) {
if (attr.second.type() == typeid(BlockDesc*) || // NOLINT
attr.second.type() == typeid(std::vector<BlockDesc*>)) // NOLINT
return true;
}
return false;
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -43,7 +43,7 @@ using GraphNodePool = std::vector< ...@@ -43,7 +43,7 @@ using GraphNodePool = std::vector<
// For example, // For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], .. // node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete // O(1) insert, delete
class OrderedNodePairPool { class OrderedNodeList {
public: public:
using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>; using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>;
using Iter = typename std::list<NodePair>::iterator; using Iter = typename std::list<NodePair>::iterator;
...@@ -53,8 +53,12 @@ class OrderedNodePairPool { ...@@ -53,8 +53,12 @@ class OrderedNodePairPool {
void Erase(ir::Node* var); void Erase(ir::Node* var);
void Erase(const std::string& var);
bool Has(ir::Node* var) { return mark_table_.count(var->Name()); } bool Has(ir::Node* var) { return mark_table_.count(var->Name()); }
bool Has(const std::string& var) { return mark_table_.count(var); }
ir::Node* NodeMatch(ir::Node* var) const; ir::Node* NodeMatch(ir::Node* var) const;
// map store non-const iterator, can not promise const // map store non-const iterator, can not promise const
int GetIndex(ir::Node* var); int GetIndex(ir::Node* var);
...@@ -67,6 +71,11 @@ class OrderedNodePairPool { ...@@ -67,6 +71,11 @@ class OrderedNodePairPool {
ConstIter end() const { return nodes_.end(); } ConstIter end() const { return nodes_.end(); }
size_t size() const { return nodes_.size(); } size_t size() const { return nodes_.size(); }
void Clear() {
mark_table_.clear();
nodes_.clear();
}
private: private:
// for searching. // for searching.
std::unordered_map<std::string, Iter> mark_table_; std::unordered_map<std::string, Iter> mark_table_;
...@@ -74,14 +83,53 @@ class OrderedNodePairPool { ...@@ -74,14 +83,53 @@ class OrderedNodePairPool {
std::list<NodePair> nodes_; std::list<NodePair> nodes_;
}; };
// valid a tensor can be reuse or not
bool NodeCanReused(ir::Node* node);
// valid a tensor can be reuse or not.
bool NodeCanReused(const VarDesc& node);
// check op has subblock or not
bool OpHasSubBlock(OpDesc* desc);
// node memory size in bytes // node memory size in bytes
size_t NodeSizeInBytes(ir::Node* n); size_t NodeSizeInBytes(ir::Node* n);
// node memory size in bytes
size_t NodeSizeInBytes(const VarDesc&);
std::string DebugString(ir::Node* var); std::string DebugString(ir::Node* var);
// std::string DebugString(VarDesc* var);
VarDesc* FindVarDescInBlock(ir::Node* n); VarDesc* FindVarDescInBlock(ir::Node* n);
template <typename Container, typename Callback>
class FilterVariableImpl {
public:
void operator()(const Container& nodes, Callback callback) {
for (auto* node : nodes) {
callback(node);
}
}
};
// filter var node for op->inputs/outputs
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
public:
void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
for (auto* var : nodes) {
if (var->IsVar() && !var->IsCtrlVar()) {
callback(var);
}
}
}
};
template <typename Container, typename Callback>
void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback);
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/memory_reuse_types.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
...@@ -27,8 +27,8 @@ namespace paddle { ...@@ -27,8 +27,8 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
TEST(OrderedNodePairPool, Normal) { TEST(OrderedNodeList, Normal) {
OrderedNodePairPool pool; OrderedNodeList pool;
std::vector<std::unique_ptr<ir::Node>> nodes; std::vector<std::unique_ptr<ir::Node>> nodes;
// clang-format off // clang-format off
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/analysis_var_pass.h" #include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <deque> #include <deque>
...@@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) { ...@@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
op1->Outputs() == op2->Outputs(); op1->Outputs() == op2->Outputs();
} }
template <typename Container, typename Callback> std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
class FilterVariableImpl {
public:
void operator()(const Container& nodes, Callback callback) {
for (auto* node : nodes) {
callback(node);
}
}
};
// filter var node for op->inputs/outputs
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
public:
void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
for (auto* var : nodes) {
if (var->IsVar() && !var->IsCtrlVar()) {
callback(var);
}
}
}
};
template <typename Container, typename Callback>
void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback);
}
std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto nodes = graph->Nodes(); auto nodes = graph->Nodes();
auto subblock_vars = GetSubBlockVars(nodes); auto subblock_vars = GetSubBlockVars(nodes);
...@@ -103,8 +75,11 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl( ...@@ -103,8 +75,11 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
} }
for (auto& var : op->outputs) { for (auto& var : op->outputs) {
if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) { if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 ||
skip_set_.count(var->Name()))
continue;
ir::Node* cache = pool_.NodeMatch(var); ir::Node* cache = pool_.NodeMatch(var);
if (var->Name() == FLAGS_memory_optimize_debug) { if (var->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "start match var " << DebugString(var) << " of op " VLOG(3) << "start match var " << DebugString(var) << " of op "
<< op->Name(); << op->Name();
...@@ -112,14 +87,14 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl( ...@@ -112,14 +87,14 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
VLOG(3) << "matched in pool : " VLOG(3) << "matched in pool : "
<< ((cache == nullptr) ? "False" : "True"); << ((cache == nullptr) ? "False" : "True");
} }
if (cache != nullptr) {
if (cache == nullptr) continue;
if (var->Name() == cache->Name()) { if (var->Name() == cache->Name()) {
VLOG(3) << "The same cache variable is cascade reused." VLOG(3) << "The same cache variable is cascade reused." << var->Name()
<< var->Name() << " is re-filled to the pool after" << " is re-filled to the pool after"
<< "the reused op is finished. Current op can not " << "the reused op is finished. Current op can not "
<< "replace it again. Skip this candidate."; << "replace it again. Skip this candidate.";
continue; continue;
}
int node_idx_in_pool = pool_.GetIndex(cache); int node_idx_in_pool = pool_.GetIndex(cache);
VLOG(3) << string::Sprintf( VLOG(3) << string::Sprintf(
...@@ -138,13 +113,15 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl( ...@@ -138,13 +113,15 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
pool_.Erase(cache); pool_.Erase(cache);
} }
}
}
// fill the pool // fill the pool
std::unordered_set<std::string> unlived_vars;
for (auto var : cfg_->LiveIn(op)) { for (auto var : cfg_->LiveIn(op)) {
if (cfg_->LiveOut(op).count(var) == 0) { if (cfg_->LiveOut(op).count(var) == 0) {
unlived_vars.emplace(var);
}
}
for (auto var : unlived_vars) {
ir::Node* var_node = cfg_->GetNodeFromVarName(var, op); ir::Node* var_node = cfg_->GetNodeFromVarName(var, op);
if (var_node == nullptr) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) { if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node, op); pool_.Insert(var_node, op);
} }
...@@ -177,7 +154,7 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl( ...@@ -177,7 +154,7 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
return graph; return graph;
} }
void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
// conditional block, while op and their grad op // conditional block, while op and their grad op
auto* sub_block_desc = auto* sub_block_desc =
AttrReader(op_desc->GetAttrMap()).Get<BlockDesc*>("sub_block"); AttrReader(op_desc->GetAttrMap()).Get<BlockDesc*>("sub_block");
...@@ -247,7 +224,7 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { ...@@ -247,7 +224,7 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const {
} }
} }
std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars( std::unordered_set<std::string> MemoryOptimizePass::GetSubBlockVars(
const std::unordered_set<ir::Node*>& nodes) const { const std::unordered_set<ir::Node*>& nodes) const {
std::unordered_set<std::string> vars; std::unordered_set<std::string> vars;
for (auto& op : nodes) { for (auto& op : nodes) {
...@@ -263,7 +240,7 @@ std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars( ...@@ -263,7 +240,7 @@ std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars(
return vars; return vars;
} }
void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, void MemoryOptimizePass::RenameVarInGraphDesc(const std::string& var,
const std::string& cache_var, const std::string& cache_var,
size_t idx) const { size_t idx) const {
for (size_t i = idx; i < cfg_->Ops().size(); ++i) { for (size_t i = idx; i < cfg_->Ops().size(); ++i) {
...@@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, ...@@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
} }
} }
void AnalysisVarPass::InitSSAGraphNodes() const { void MemoryOptimizePass::InitSSAGraphNodes() const {
std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars; std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars;
if (var_nodes_.empty()) { if (var_nodes_.empty()) {
for (auto* op : cfg_->Ops()) { for (auto* op : cfg_->Ops()) {
...@@ -297,9 +274,10 @@ void AnalysisVarPass::InitSSAGraphNodes() const { ...@@ -297,9 +274,10 @@ void AnalysisVarPass::InitSSAGraphNodes() const {
} }
} }
void AnalysisVarPass::RenameVarInGraphNode(const std::string& var, void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
const std::string& cache_var, const std::string& cache_var,
size_t idx, ir::Graph* graph) const { size_t idx,
ir::Graph* graph) const {
// if replace happens, we need to create a newer version cache_var // if replace happens, we need to create a newer version cache_var
// but use the same dims/data_type with var. // but use the same dims/data_type with var.
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
...@@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var, ...@@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var,
var_nodes_.at(var).clear(); var_nodes_.at(var).clear();
} }
bool AnalysisVarPass::NodeCanReused(ir::Node* node) const {
if (!node->IsVar() || node->IsCtrlVar()) return false;
auto* desc = node->Var();
auto type = desc->GetType();
if (desc->Persistable() || type != proto::VarType::LOD_TENSOR ||
desc->GetShape().empty()) {
return false;
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
std::string name = node->Name();
if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
return false;
if (skip_set_.count(name)) return false;
for (auto* op : node->inputs) {
if (op->Op()->HasAttr("force_cpu")) {
// op output force generated in cpu, can not be reused.
return framework::AttrReader(op->Op()->GetAttrMap())
.Get<bool>("force_cpu") == 0;
}
}
return true;
}
bool AnalysisVarPass::OpHasSubBlock(OpDesc* desc) const {
const AttributeMap& attrs = desc->GetAttrMap();
for (auto& attr : attrs) {
if (attr.second.type() == typeid(BlockDesc*) || // NOLINT
attr.second.type() == typeid(std::vector<BlockDesc*>)) // NOLINT
return true;
}
return false;
}
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) { std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kAllOpDescs), PADDLE_ENFORCE(graph.Has(kAllOpDescs),
"Graph has no attribute of kAllOpDescs."); "Graph has no attribute of kAllOpDescs.");
...@@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name, ...@@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name,
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(analysis_var_pass, paddle::framework::details::AnalysisVarPass) REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kGraphNodePool) .RequireGraphAttr(paddle::framework::details::kGraphNodePool)
.RequireGraphAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kAllOpDescs);
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_reuse_types.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
...@@ -35,12 +35,10 @@ namespace details { ...@@ -35,12 +35,10 @@ namespace details {
constexpr char kAllOpDescs[] = "all_op_descs"; constexpr char kAllOpDescs[] = "all_op_descs";
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph); std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// sort op in bfs order
std::vector<ir::Node*> BFSSortGraphOps(const ir::Graph& graph);
class ControlFlowGraph; class ControlFlowGraph;
class AnalysisVarPass : public ir::Pass { class MemoryOptimizePass : public ir::Pass {
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
...@@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass { ...@@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass {
ir::Graph* graph) const; ir::Graph* graph) const;
void SubGraphOptimize(OpDesc* op_desc) const; void SubGraphOptimize(OpDesc* op_desc) const;
// valid a tensor can be reuse or not
bool NodeCanReused(ir::Node* node) const;
// scan subblock and collect the output/input variables. // scan subblock and collect the output/input variables.
std::unordered_set<std::string> GetSubBlockVars( std::unordered_set<std::string> GetSubBlockVars(
const std::unordered_set<ir::Node*>&) const; const std::unordered_set<ir::Node*>&) const;
// check op has subblock or not
bool OpHasSubBlock(OpDesc* desc) const;
private: private:
// Reuse Node Pool, Owned. // Reuse Node Pool, Owned.
mutable OrderedNodePairPool pool_; mutable OrderedNodeList pool_;
// controlflow Graph // controlflow Graph
mutable std::unique_ptr<ControlFlowGraph> cfg_; mutable std::unique_ptr<ControlFlowGraph> cfg_;
// skip set // skip set
......
...@@ -12,63 +12,19 @@ ...@@ -12,63 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/analysis_var_pass.h" #include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::DummyOp, REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker, paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference); paddle::framework::DummyVarTypeInference);
...@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() { ...@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() {
return prog; return prog;
} }
template <typename Container>
inline static std::string DebugString(const Container& c) {
std::stringstream ss;
for (auto& item : c) {
ss << item << " ";
}
return ss.str();
}
TEST(CFGGraph, IRGraph) { TEST(CFGGraph, IRGraph) {
// prepare ir graph // prepare ir graph
auto prog = FillProgramDesc(); auto prog = FillProgramDesc();
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h" #include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -32,7 +33,8 @@ enum OpInfoFillType { ...@@ -32,7 +33,8 @@ enum OpInfoFillType {
kOpProtoAndCheckerMaker = 1, kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2, kGradOpDescMaker = 2,
kVarTypeInference = 3, kVarTypeInference = 3,
kShapeInference = 4 kShapeInference = 4,
kInplaceOpInference = 5
}; };
template <typename T> template <typename T>
...@@ -48,8 +50,11 @@ struct OpInfoFillTypeID { ...@@ -48,8 +50,11 @@ struct OpInfoFillTypeID {
? kVarTypeInference ? kVarTypeInference
: (std::is_base_of<InferShapeBase, T>::value : (std::is_base_of<InferShapeBase, T>::value
? kShapeInference ? kShapeInference
: (std::is_base_of<
InplaceOpInference, T>::value
? kInplaceOpInference
: static_cast<OpInfoFillType>( : static_cast<OpInfoFillType>(
-1))))); -1))))));
} }
}; };
...@@ -139,6 +144,16 @@ struct OpInfoFiller<T, kShapeInference> { ...@@ -139,6 +144,16 @@ struct OpInfoFiller<T, kShapeInference> {
} }
}; };
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_inplace_ = [](const OpDesc& op_desc, BlockDesc* block) {
T infer;
return infer(op_desc, block);
};
}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
namespace paddle {
namespace framework {
/*
Inplace Inference for create In->Out pairs for inplaced operator.
If we specify a pair of corresponding names. For example, X->Out.
then Out will inplaced use X's memory. The base class will do
legality validation for both variables.
*/
class InplaceOpInference {
public:
virtual ~InplaceOpInference() {}
virtual std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, BlockDesc* block) const = 0;
};
class InplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, BlockDesc* block) const {
std::unordered_map<std::string, std::string> ret;
auto in_out_var_names_pair = this->Apply(op_desc, block);
for (auto& pair : in_out_var_names_pair) {
PADDLE_ENFORCE(!op_desc.Input(pair.first).empty(),
string::Sprintf("op %s do not have input of %s!",
op_desc.Type(), pair.first));
PADDLE_ENFORCE(!op_desc.Output(pair.second).empty(),
string::Sprintf("op %s do not have output of %s!",
op_desc.Type(), pair.second));
auto& in_name = op_desc.Input(pair.first).at(0);
auto& out_name = op_desc.Output(pair.second).at(0);
auto in = block->FindRecursiveOrCreateVar(in_name);
auto out = block->FindRecursiveOrCreateVar(out_name);
if (TryInplaceInputOutput(in, out)) ret.insert({in_name, out_name});
}
return ret;
}
protected:
virtual std::unordered_map<std::string, std::string> Apply(
const OpDesc& op_desc, BlockDesc* block) const = 0;
bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
return in.Name() != out.Name() && details::NodeCanReused(in) &&
details::NodeCanReused(out) &&
details::NodeSizeInBytes(out) <= details::NodeSizeInBytes(in);
}
};
/*
Inplace In and Out for operator only have an Input and an Output.
For example, activation op.
*/
class SingleOpInplaceInToOut : public InplaceInToOut {
protected:
std::unordered_map<std::string, std::string> Apply(
const OpDesc& op_desc, BlockDesc* block) const override {
PADDLE_ENFORCE(!op_desc.InputNames().empty(),
"Op inputs must not be empty");
PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
"Op outputs must not be empty");
auto x_name = op_desc.InputNames().at(0);
auto out_name = op_desc.OutputNames().at(0);
return std::unordered_map<std::string, std::string>{{x_name, out_name}};
}
};
/*
Gradient op. Inplace output use it's Input.
For example, Input@Grad->Input reuse strategy.
*/
class GradOpInplaceInToOut : public InplaceInToOut {
protected:
std::unordered_map<std::string, std::string> Apply(
const OpDesc& op_desc, BlockDesc* block) const override {
std::unordered_map<std::string, std::string> ret;
std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
op_desc.OutputNames().end());
for (auto& input_name : op_desc.InputNames()) {
if (output_names.count(GradVarName(input_name))) {
ret.insert({input_name, GradVarName(input_name)});
}
}
return ret;
}
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iterator>
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SingleOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SingleGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("single_op_grad");
op->SetInput("Out", OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<OpDesc>(op);
}
};
class SingleOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->HasInput("X");
ctx->HasOutput("Out");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
};
class SingleGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->HasInput(framework::GradVarName("Out"));
ctx->HasOutput(framework::GradVarName("X"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
};
class MultiOutOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddInput("Y", "").AsDuplicable();
AddInput("Z", "").AsDuplicable();
AddOutput("Out", "");
AddOutput("YOut", "");
AddOutput("ZOut", "");
AddOutput("NotReuseOut", "");
AddComment("");
}
};
class MultiOutShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("X", "Out");
ctx->ShareDim("Y", "YOut");
ctx->ShareDim("Z", "ZOut");
}
};
class MultiGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("multi_out_grad");
op->SetInput("X", Input("X"));
op->SetOutput(framework::GradVarName("Y"), OutputGrad("YOut"));
op->SetOutput(framework::GradVarName("X"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Z"), OutputGrad("ZOut"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
class MultiOutGradShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Y"),
ctx->GetInputDim(framework::GradVarName("YOut")));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->SetOutputDim(framework::GradVarName("Z"),
ctx->GetInputDim(framework::GradVarName("ZOut")));
}
};
class MultiOutInplaceInToOut : public framework::InplaceInToOut {
public:
using framework::InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const OpDesc& op_desc, BlockDesc* block) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"},
};
}
};
class MultiOutGradInplaceInToOut : public framework::InplaceInToOut {
public:
using framework::InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const OpDesc& op_desc, BlockDesc* block) const override {
return std::unordered_map<std::string, std::string>{
{framework::GradVarName("YOut"), framework::GradVarName("Y")},
{framework::GradVarName("Out"), framework::GradVarName("X")},
{framework::GradVarName("ZOut"), framework::GradVarName("Z")},
};
}
};
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
REGISTER_OPERATOR(single_op, f::NOP, f::SingleOpMaker, f::SingleGradOpMaker,
f::SingleOpInplaceInToOut, f::SingleOpShapeInference);
REGISTER_OPERATOR(single_op_grad, f::NOP, f::SingleOpInplaceInToOut,
f::SingleGradOpShapeInference);
REGISTER_OPERATOR(multi_out_op, f::NOP, f::MultiOutOpMaker, f::MultiGradOpMaker,
f::MultiOutInplaceInToOut, f::MultiOutShapeInference);
REGISTER_OPERATOR(multi_out_grad, f::NOP, f::MultiOutGradInplaceInToOut,
f::MultiOutGradShapeInference);
namespace paddle {
namespace framework {
TEST(InferInplace, SingleOpInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("single_op");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 64});
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_out");
prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block());
EXPECT_EQ(in_to_outs.size(), 1ul);
auto it = in_to_outs.begin();
EXPECT_EQ(it->first, "test2_a");
EXPECT_EQ(it->second, "test2_out");
}
TEST(InferInplace, SingleGradOpInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("single_op_grad");
op->SetInput(GradVarName("Out"), {"test2_out"});
op->SetOutput(GradVarName("X"), {"test2_a", "test2_b", "test2_c"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 16});
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_out");
prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block());
EXPECT_EQ(in_to_outs.size(), 1ul);
auto it = in_to_outs.begin();
EXPECT_EQ(it->first, "test2_out");
EXPECT_EQ(it->second, "test2_a");
}
TEST(InferInplace, MultiOutInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("multi_out_op");
op->SetInput("X", {"a0", "a1"});
op->SetInput("Y", {"b0"});
op->SetInput("Z", {"c0", "c1"});
op->SetOutput("Out", {"o0"});
op->SetOutput("YOut", {"y0"});
op->SetOutput("ZOut", {"z0"});
prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("o0");
prog.MutableBlock(0)->Var("y0");
prog.MutableBlock(0)->Var("z0");
prog.MutableBlock(0)->Var("a0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("b0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("c0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("o0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("y0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block());
EXPECT_EQ(in_to_outs.size(), 3ul);
std::unordered_map<std::string, std::string> expects = {
{"a0", "o0"}, {"b0", "y0"}, {"c0", "z0"},
};
EXPECT_TRUE(expects == in_to_outs);
}
TEST(InferInplace, MultiGradInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("multi_out_grad");
op->SetInput(GradVarName("Out"), {"o0"});
op->SetInput(GradVarName("YOut"), {"y0"});
op->SetInput(GradVarName("ZOut"), {"z0"});
op->SetOutput(GradVarName("X"), {"a0", "a1"});
op->SetOutput(GradVarName("Y"), {"b0"});
op->SetOutput(GradVarName("Z"), {"c0", "c1"});
prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("o0");
prog.MutableBlock(0)->Var("y0");
prog.MutableBlock(0)->Var("z0");
prog.MutableBlock(0)->Var("a0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("b0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("c0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("o0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("y0")->SetShape({32, 16});
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block());
EXPECT_EQ(in_to_outs.size(), 3ul);
std::unordered_map<std::string, std::string> expects = {
{"o0", "a0"}, {"y0", "b0"}, {"z0", "c0"},
};
EXPECT_TRUE(expects == in_to_outs);
}
} // namespace framework
} // namespace paddle
...@@ -52,16 +52,29 @@ bool HasCircleHelper( ...@@ -52,16 +52,29 @@ bool HasCircleHelper(
ir::Node *node, ir::Node *node,
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list, const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::unordered_set<ir::Node *> *visited, std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace) { std::unordered_set<ir::Node *> *in_trace,
std::vector<std::vector<ir::Node *>> *circles) {
if (visited->find(node) == visited->end()) { if (visited->find(node) == visited->end()) {
visited->insert(node); visited->insert(node);
in_trace->insert(node); in_trace->insert(node);
for (ir::Node *in : adj_list.at(node)) { for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() && if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace)) { HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
return true; return true;
} else if (in_trace->find(in) != in_trace->end()) { } else if (in_trace->find(in) != in_trace->end()) {
if (circles != nullptr) {
std::vector<ir::Node *> circle;
circle.emplace_back(in);
ir::Node *p = in;
for (auto &adj : adj_list.at(p)) {
if (in_trace->count(adj)) {
circle.emplace_back(adj);
p = adj;
}
}
circles->emplace_back(circle);
}
return true; return true;
} }
} }
...@@ -71,11 +84,12 @@ bool HasCircleHelper( ...@@ -71,11 +84,12 @@ bool HasCircleHelper(
} }
bool HasCircleInternal( bool HasCircleInternal(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) { const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::vector<std::vector<ir::Node *>> *circles) {
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace; std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) { for (auto &adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) { if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
return true; return true;
} }
} }
...@@ -84,13 +98,18 @@ bool HasCircleInternal( ...@@ -84,13 +98,18 @@ bool HasCircleInternal(
} // namespace } // namespace
bool HasCircle(const Graph &graph) { bool HasCircle(const Graph &graph) {
return HasCircleInternal(BuildOperationAdjList(graph)); return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
}
bool FindCircleSubGraph(const Graph &graph,
std::vector<std::vector<ir::Node *>> *circles) {
return HasCircleInternal(BuildOperationAdjList(graph), circles);
} }
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) { std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list = std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph); BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list)); PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret; std::vector<ir::Node *> ret;
for (auto adj : adj_list) { for (auto adj : adj_list) {
......
...@@ -28,6 +28,11 @@ namespace ir { ...@@ -28,6 +28,11 @@ namespace ir {
// Test if the graph contains circle. // Test if the graph contains circle.
bool HasCircle(const Graph &graph); bool HasCircle(const Graph &graph);
// Find All Circles for debugging,
// store all subgraph in circles.
bool FindCircleSubGraph(const Graph &graph,
std::vector<std::vector<ir::Node *>> *circles);
size_t GraphNum(const Graph &graph); size_t GraphNum(const Graph &graph);
// Topology Sort the operations in the graph from inputs to outputs. // Topology Sort the operations in the graph from inputs to outputs.
......
...@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) { ...@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
// v4->outputs.push_back(o5); // v4->outputs.push_back(o5);
} }
TEST(GraphHelperTest, Circles) {
ProgramDesc prog;
Graph g(prog);
BuildCircleGraph(&g);
std::vector<std::vector<ir::Node*>> circles;
ASSERT_TRUE(FindCircleSubGraph(g, &circles));
ASSERT_EQ(circles.size(), 1UL);
}
TEST(GraphHelperTest, GraphNum) { TEST(GraphHelperTest, GraphNum) {
ProgramDesc prog; ProgramDesc prog;
......
...@@ -38,6 +38,7 @@ struct OpInfo { ...@@ -38,6 +38,7 @@ struct OpInfo {
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_; InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_; InferShapeFN infer_shape_;
InferInplaceOpFN infer_inplace_;
bool HasOpProtoAndChecker() const { bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;
......
...@@ -57,5 +57,8 @@ using InferVarTypeFN = ...@@ -57,5 +57,8 @@ using InferVarTypeFN =
using InferShapeFN = std::function<void(InferShapeContext*)>; using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(const OpDesc&, BlockDesc*)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -34,6 +34,6 @@ TEST(Benchmark, PersistToFile) { ...@@ -34,6 +34,6 @@ TEST(Benchmark, PersistToFile) {
benchmark.SetLatency(220); benchmark.SetLatency(220);
benchmark.PersistToFile("1.log"); benchmark.PersistToFile("1.log");
benchmark.PersistToFile("1.log"); benchmark.PersistToFile("2.log");
benchmark.PersistToFile("1.log"); benchmark.PersistToFile("3.log");
} }
...@@ -551,8 +551,10 @@ namespace ops = paddle::operators; ...@@ -551,8 +551,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \ ::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \ ::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker); \ ::paddle::operators::OP_NAME##GradMaker, \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) ::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
......
...@@ -604,13 +604,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker { ...@@ -604,13 +604,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
class BatchNormInplaceInToOut : public framework::InplaceInToOut {
public:
using InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"},
};
return inplace_in_to_out;
}
};
class BatchNormGradInplaceInToOut : public framework::InplaceInToOut {
public:
using InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
// Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
{framework::GradVarName("Y"), framework::GradVarName("X")},
{"SavedMean", framework::GradVarName("Scale")},
{"SavedVariance", framework::GradVarName("Bias")},
};
return inplace_in_to_out;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType, ops::BatchNormGradMaker); ops::BatchNormOpInferVarType, ops::BatchNormGradMaker,
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); ops::BatchNormInplaceInToOut);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp,
ops::BatchNormGradInplaceInToOut);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>, batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -18,6 +18,7 @@ namespace ops = paddle::operators; ...@@ -18,6 +18,7 @@ namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add); REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out", REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
"X"); "X");
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add, elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel<T> { ...@@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
} }
}; };
class ElementwiseOpInplace : public framework::InplaceInToOut {
public:
using framework::InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"},
};
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -299,6 +313,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> { ...@@ -299,6 +313,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \ __ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \ ::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker); \ op_type##GradMaker, \
::paddle::operators::ElementwiseOpInplace); \
REGISTER_OPERATOR(op_type##_grad, \ REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad) ::paddle::operators::ElementwiseOpExplicitGrad)
...@@ -267,6 +267,35 @@ class Flatten2GradOp : public framework::OperatorBase { ...@@ -267,6 +267,35 @@ class Flatten2GradOp : public framework::OperatorBase {
} }
}; };
class FlattenOpInplaceInToOut : public framework::InplaceInToOut {
public:
using InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{"X", "Out"},
};
return inplace_in_to_out;
}
};
class FlattenGradInplaceinToOut : public framework::InplaceInToOut {
using InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{framework::GradVarName("Out"), framework::GradVarName("X")},
};
return inplace_in_to_out;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -275,10 +304,13 @@ USE_OP(reshape); ...@@ -275,10 +304,13 @@ USE_OP(reshape);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker, REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenOpInferShape, ops::FlattenOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>,
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape); ops::FlattenOpInplaceInToOut);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape,
ops::FlattenGradInplaceinToOut);
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker, REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker); ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker,
ops::FlattenOpInplaceInToOut);
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp, REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::Flatten2GradInferShape); ops::Flatten2GradInferShape, ops::FlattenGradInplaceinToOut);
...@@ -327,14 +327,45 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -327,14 +327,45 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
} }
}; };
class ReshapeOpInplaceInToOut : public framework::InplaceInToOut {
public:
using InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{"X", "Out"},
};
return inplace_in_to_out;
}
};
class ReshapeGradInplaceInToOut : public framework::InplaceInToOut {
using InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{framework::GradVarName("Out"), framework::GradVarName("X")},
};
return inplace_in_to_out;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>,
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); ops::ReshapeOpInplaceInToOut);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
ops::ReshapeGradInplaceInToOut);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel); int64_t, ops::ReshapeKernel);
...@@ -344,8 +375,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ...@@ -344,8 +375,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker); ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp); REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::ReshapeGradInplaceInToOut);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel); int64_t, ops::ReshapeKernel);
......
...@@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { ...@@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
using ScaleOpInplace = framework::SingleOpInplaceInToOut;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker, REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker,
ops::ScaleOpVarTypeInference); ops::ScaleOpVarTypeInference, ops::ScaleOpInplace);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>, scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>, ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
......
...@@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
return std::unique_ptr<framework::OpDesc>(op); return std::unique_ptr<framework::OpDesc>(op);
} }
}; };
class SoftmaxInplaceInToOut : public framework::InplaceInToOut {
public:
using framework::InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"},
};
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -1096,6 +1096,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1096,6 +1096,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_early_delete", "memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; }, [](const BuildStrategy &self) { return self.memory_early_delete_; },
[](BuildStrategy &self, bool b) { self.memory_early_delete_ = b; }) [](BuildStrategy &self, bool b) { self.memory_early_delete_ = b; })
.def_property(
"enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; },
[](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
.def("_finalize_strategy_and_create_passes", .def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> { [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true); return self.CreatePassesFromStrategy(true);
......
...@@ -158,7 +158,8 @@ def __bootstrap__(): ...@@ -158,7 +158,8 @@ def __bootstrap__():
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
'sync_nccl_allreduce', 'limit_of_tmp_allocation', 'sync_nccl_allreduce', 'limit_of_tmp_allocation',
'times_excess_than_required_tmp_allocation' 'times_excess_than_required_tmp_allocation',
'enable_inplace_whitelist'
] ]
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
......
...@@ -174,6 +174,11 @@ class CompiledProgram(object): ...@@ -174,6 +174,11 @@ class CompiledProgram(object):
self._exec_strategy.num_threads = cpu_num * 2 self._exec_strategy.num_threads = cpu_num * 2
trainers_endpoints = self._program._trainers_endpoints trainers_endpoints = self._program._trainers_endpoints
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
self._build_strategy.enable_inplace = False if self._program._is_mem_optimized else True
if self._build_strategy.num_trainers > 1 and trainers_endpoints: if self._build_strategy.num_trainers > 1 and trainers_endpoints:
assert self._build_strategy.num_trainers == len( assert self._build_strategy.num_trainers == len(
trainers_endpoints), "num_trainers == len(end_points)" trainers_endpoints), "num_trainers == len(end_points)"
......
...@@ -1725,6 +1725,19 @@ class Program(object): ...@@ -1725,6 +1725,19 @@ class Program(object):
self._trainers_endpoints = [] self._trainers_endpoints = []
# the distributed lookup table names # the distributed lookup table names
self._distributed_lookup_table = None self._distributed_lookup_table = None
# @deprecated(the python memory optimize transpiler is deprecated)
# whether the program is optimized by memory_optimize_transpiler
self.__is_mem_optimized = False
@property
def _is_mem_optimized(self):
# if the program is optimized, operator input/outputs
# maybe same, which conflict with save_inference_model.
return self.__is_mem_optimized
@_is_mem_optimized.setter
def _is_mem_optimized(self, target):
self.__is_mem_optimized = target
@property @property
def op_role(self): def op_role(self):
...@@ -1744,7 +1757,7 @@ class Program(object): ...@@ -1744,7 +1757,7 @@ class Program(object):
return self._current_role return self._current_role
@op_role.setter @op_role.setter
def set_op_role(self, role): def op_role(self, role):
self._current_role = role self._current_role = role
@property @property
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import os import os
import errno import errno
import warnings
import time import time
import shutil import shutil
import six import six
...@@ -931,6 +932,13 @@ def save_inference_model(dirname, ...@@ -931,6 +932,13 @@ def save_inference_model(dirname,
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
if main_program._is_mem_optimized:
warnings.warn(
"save_inference_model must put before you call memory_optimize. \
the memory_optimize will modify the original program, \
is not suitable for saving inference model \
we save the original program as inference model.",
RuntimeWarning)
# fix the bug that the activation op's output as target will be pruned. # fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance. # will affect the inference performance.
......
...@@ -146,6 +146,9 @@ class ParallelExecutor(object): ...@@ -146,6 +146,9 @@ class ParallelExecutor(object):
# step4: get main_program, scope, local_scopes # step4: get main_program, scope, local_scopes
main = main_program if main_program \ main = main_program if main_program \
else framework.default_main_program() else framework.default_main_program()
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
build_strategy.enable_inplace = False if main._is_mem_optimized else True
scope = scope if scope is not None else executor.global_scope() scope = scope if scope is not None else executor.global_scope()
if share_vars_from and not isinstance(share_vars_from, if share_vars_from and not isinstance(share_vars_from,
......
...@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None, seed=None,
use_parallel_executor=True, use_parallel_executor=True,
use_reduce=False, use_reduce=False,
use_ir_memory_optimize=False, use_ir_memory_optimize=True,
enable_inplace=True,
fuse_elewise_add_act_ops=False, fuse_elewise_add_act_ops=False,
fuse_relu_depthwise_conv=False, fuse_relu_depthwise_conv=False,
optimizer=fluid.optimizer.Adam, optimizer=fluid.optimizer.Adam,
...@@ -60,7 +61,6 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -60,7 +61,6 @@ class TestParallelExecutorBase(unittest.TestCase):
main.random_seed = seed main.random_seed = seed
loss = method(use_feed=feed_dict is not None) loss = method(use_feed=feed_dict is not None)
if optimizer: if optimizer:
optimizer().minimize(loss) optimizer().minimize(loss)
...@@ -80,7 +80,11 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -80,7 +80,11 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize build_strategy.memory_optimize = use_ir_memory_optimize
# python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way.
build_strategy.enable_inplace = False if memory_opt else enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda(): if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True build_strategy.remove_unnecessary_lock = True
if use_parallel_executor: if use_parallel_executor:
...@@ -100,8 +104,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -100,8 +104,7 @@ class TestParallelExecutorBase(unittest.TestCase):
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter): for i in range(iter):
run_executor( run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
last_loss, = run_executor( last_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
......
...@@ -25,6 +25,7 @@ import paddle.fluid.layers as layers ...@@ -25,6 +25,7 @@ import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer import paddle.fluid.optimizer as optimizer
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
from paddle.fluid.io import save_inference_model, load_inference_model from paddle.fluid.io import save_inference_model, load_inference_model
from paddle.fluid.transpiler import memory_optimize
class TestBook(unittest.TestCase): class TestBook(unittest.TestCase):
...@@ -87,5 +88,31 @@ class TestBook(unittest.TestCase): ...@@ -87,5 +88,31 @@ class TestBook(unittest.TestCase):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
class TestSaveInferenceModel(unittest.TestCase):
def test_save_inference_model(self):
MODEL_DIR = "./tmp/inference_model2"
init_program = Program()
program = Program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
memory_optimize(program, print_log=True)
self.assertEqual(program._is_mem_optimized, True)
# will print warning message
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase
def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(3):
hidden = fluid.layers.fc(
hidden,
size=200,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestIrInplace(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _fc_with_batchnorm(self,
ir_memory_optimize,
enable_inplace,
memory_opt=False):
if not core.is_compiled_with_cuda():
return
np.random.seed(5)
img = np.random.random(size=[32, 784]).astype(np.float32)
label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=True,
memory_opt=memory_opt,
use_ir_memory_optimize=ir_memory_optimize,
enable_inplace=enable_inplace)
def test_fc_with_batchnorm(self, delta=1e-3):
loss00 = self._fc_with_batchnorm(False, False)
loss10 = self._fc_with_batchnorm(True, False)
loss01 = self._fc_with_batchnorm(False, True)
loss11 = self._fc_with_batchnorm(True, True)
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta)
...@@ -200,7 +200,7 @@ class TestResnet(TestParallelExecutorBase): ...@@ -200,7 +200,7 @@ class TestResnet(TestParallelExecutorBase):
model, model,
use_cuda, use_cuda,
iter=20, iter=20,
delta2=1e-6): delta2=1e-5):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
...@@ -228,7 +228,7 @@ class TestResnet(TestParallelExecutorBase): ...@@ -228,7 +228,7 @@ class TestResnet(TestParallelExecutorBase):
optimizer=optimizer) optimizer=optimizer)
for loss in zip(all_reduce_first_loss, reduce_first_loss): for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-5)
for loss in zip(all_reduce_last_loss, reduce_last_loss): for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2) self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
...@@ -258,17 +258,17 @@ class TestResnet(TestParallelExecutorBase): ...@@ -258,17 +258,17 @@ class TestResnet(TestParallelExecutorBase):
enable_sequential_execution=True) enable_sequential_execution=True)
for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq): for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-5)
for loss in zip(all_reduce_last_loss, all_reduce_last_loss_seq): for loss in zip(all_reduce_last_loss, all_reduce_last_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2) self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
for loss in zip(reduce_first_loss, reduce_first_loss_seq): for loss in zip(reduce_first_loss, reduce_first_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-5)
for loss in zip(reduce_last_loss, reduce_last_loss_seq): for loss in zip(reduce_last_loss, reduce_last_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2) self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
for loss in zip(all_reduce_first_loss_seq, reduce_first_loss_seq): for loss in zip(all_reduce_first_loss_seq, reduce_first_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-5)
for loss in zip(all_reduce_last_loss_seq, reduce_last_loss_seq): for loss in zip(all_reduce_last_loss_seq, reduce_last_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2) self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
...@@ -277,7 +277,7 @@ class TestResnet(TestParallelExecutorBase): ...@@ -277,7 +277,7 @@ class TestResnet(TestParallelExecutorBase):
use_cuda=True, use_cuda=True,
use_reduce=False, use_reduce=False,
iter=20, iter=20,
delta2=1e-6): delta2=1e-5):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
...@@ -308,7 +308,7 @@ class TestResnet(TestParallelExecutorBase): ...@@ -308,7 +308,7 @@ class TestResnet(TestParallelExecutorBase):
optimizer=optimizer) optimizer=optimizer)
self.assertAlmostEquals( self.assertAlmostEquals(
np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6) np.mean(parallel_first_loss), single_first_loss[0], delta=1e-5)
self.assertAlmostEquals( self.assertAlmostEquals(
np.mean(parallel_last_loss), single_last_loss[0], delta=delta2) np.mean(parallel_last_loss), single_last_loss[0], delta=delta2)
......
...@@ -540,6 +540,7 @@ def memory_optimize(input_program, ...@@ -540,6 +540,7 @@ def memory_optimize(input_program,
if skip_opt_set is not None: if skip_opt_set is not None:
skip_opt_set = set(map(to_name_str, skip_opt_set)) skip_opt_set = set(map(to_name_str, skip_opt_set))
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
input_program._is_mem_optimized = True
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
...@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None): ...@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None):
None None
""" """
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
input_program._is_mem_optimized = True
for cfg in cfgs: for cfg in cfgs:
cfg.release_memory(skip_opt_set=skip_opt_set) cfg.release_memory(skip_opt_set=skip_opt_set)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册