未验证 提交 4e1bc6e8 编写于 作者: Z Zeng Jinle 提交者: GitHub

Rewrite inplace pass and fix gc bug (#17126)

* fix op graph view
test=develop

* rewrite inplace pass and fix reference count pass bug
test=develop

* fix unittest failed
test=develop

* follow comments, test=develop
上级 08773b60
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,20 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include <algorithm>
#include <deque>
#include <iterator>
#include <memory>
#include <map>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means one op output variable reuse the input space.
......@@ -56,6 +50,10 @@ DEFINE_bool(
DECLARE_string(memory_optimize_debug);
namespace paddle {
namespace framework {
namespace details {
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
"sigmoid",
......@@ -83,488 +81,376 @@ const std::string kInplacedOpWhiteList[] = { // NOLINT
// but the static size during compiling time would be wrong.
// Use a flag to indicate such ops. Please fix me when found a better way.
static const std::unordered_set<std::string> kSameShapeOpWhiteSet{ // NOLINT
"reshape2"
"reshape2", "reshape2_grad"
};
// clang-format on
namespace paddle {
namespace framework {
namespace details {
class InplacePass : public ir::Pass {
public:
InplacePass();
static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
// if next op is inplaced, then return the output var
// otherwise return nullptr
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
ir::Node* inplaced_var = nullptr;
for (auto* next_op : var->outputs) {
for (auto* output : next_op->outputs) {
if (output->IsVar() && !output->IsCtrlVar() &&
output->Name() == var->Name()) {
inplaced_var = output;
}
protected:
void ApplyImpl(ir::Graph *graph) const override;
private:
// Collect vars that cannot be reused
// e.g.: subblock ops in/out, distributed ops in/out, op_role_var
void CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const;
// Check whether var_name should be skipped
bool IsSkipVar(const std::string &var_name) const;
// Rename out with name of in, and guarantee that the graph is
// still a SSA graph
void RenameInOut(ir::Node *op, ir::Node *in, ir::Node *out) const;
// Check whether var is the last version one in SSA graph
bool IsLastVersionVar(ir::Node *var) const;
// Check whether all `ops` is the preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
// Find node whose name is equal to the given name
static ir::Node *FindNodeByName(const std::string &name,
const std::vector<ir::Node *> &nodes);
// Get all versions vars named var_name
std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
private:
// SSA graph. var_name -> each version of vars
mutable std::map<std::string, std::vector<ir::Node *>> ssa_map_;
// Skip vars, including subblock ops in/out, distributed ops in/out,
// op_role_var
mutable std::unordered_set<std::string> skip_vars_;
// Op whitelist which should not peform inplace
// Only enabled when FLAGS_enable_inplace_whitelist is true.
mutable std::unordered_set<std::string> whitelist_ops_;
};
InplacePass::InplacePass() {
if (FLAGS_enable_inplace_whitelist) {
for (auto &s : kInplacedOpWhiteList) {
whitelist_ops_.emplace(s);
}
}
return inplaced_var;
}
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
if (var->inputs.empty()) return nullptr;
auto* prev_op = var->inputs.at(0);
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
[&](ir::Node* node) {
if (node->IsVar() && !node->IsCtrlVar() &&
node->Name() == var->Name()) {
return true;
} else {
return false;
}
});
return input_it == prev_op->inputs.end() ? nullptr : *input_it;
std::vector<ir::Node *> *InplacePass::AllVersionVars(
const std::string &var_name) const {
auto iter = ssa_map_.find(var_name);
PADDLE_ENFORCE(iter != ssa_map_.end(), "cannot find var %s in ssa graph",
var_name);
PADDLE_ENFORCE(!iter->second.empty(), "var %s is empty in ssa graph",
var_name);
return &(iter->second);
}
InplacePass::InplacePass() : Pass() {
if (FLAGS_enable_inplace_whitelist) {
for (auto& s : kInplacedOpWhiteList) {
whitelist_.emplace(s);
bool InplacePass::IsSkipVar(const std::string &var_name) const {
return skip_vars_.count(var_name) > 0;
}
bool InplacePass::IsLastVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->back() == var;
}
bool InplacePass::CheckOpDeps(ir::Node *op,
const std::vector<ir::Node *> &ops) const {
std::unordered_set<ir::Node *> other_ops(ops.begin(), ops.end());
other_ops.erase(op);
if (other_ops.empty()) return true;
// Traverse all preceding ops of op
std::queue<ir::Node *> queue;
std::unordered_set<ir::Node *> visited_ops;
queue.push(op);
visited_ops.insert(op);
// Visit all preceding ops of `op`, and erase it from other_ops if it is
// inside other_ops. Return true only if other_ops is empty(), which means
// that all `ops` are preceding ops of `op`.
while (!queue.empty()) {
auto *cur_op = queue.front();
queue.pop();
for (auto *in_var : cur_op->inputs) {
for (auto *in_op : in_var->inputs) {
if (visited_ops.count(in_op) != 0) {
continue;
}
visited_ops.insert(in_op);
queue.push(in_op);
other_ops.erase(in_op);
if (other_ops.empty()) return true;
}
}
}
return false;
}
void InplacePass::InitSSAGraphNodes() const {
std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars;
for (auto* op : view_.AllOps()) {
for (auto* node : op->inputs) {
if (!node->IsVar() || node->IsCtrlVar()) continue;
if (all_vars[node->Name()].count(node) == 0) {
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
void InplacePass::CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const {
// 1. Collect op role vars
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
"Graph should have attr %s", details::kMemOptSkipVars);
auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto &var : mem_opt_whitelist) {
skip_vars_.emplace(var);
}
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other's name.
// Also check the ops which has sub-block
auto update_skip_set = [&](ir::Node *node) {
for (auto &in : node->inputs) {
if (in->IsVar() && in->Var() != nullptr) {
skip_vars_.emplace(in->Name());
}
for (auto* node : op->outputs) {
if (!node->IsVar() || node->IsCtrlVar()) continue;
if (all_vars[node->Name()].count(node) == 0) {
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
}
for (auto &out : node->outputs) {
if (out->IsVar() && out->Var() != nullptr) {
skip_vars_.emplace(out->Name());
}
}
}
void InplacePass::ApplyImpl(ir::Graph* graph) const {
var_nodes_.clear();
view_.Build(graph);
InitSSAGraphNodes();
};
auto cnt = 0;
for (auto* op : view_.AllOps()) {
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
for (auto *node : ops) {
if (!node->IsOp()) continue;
// avoid optimizing the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) {
update_skip_set(node);
continue;
TryInplaceOpInputOutput(op, graph);
}
}
void InplacePass::InplaceModifyDesc(const std::string& var,
const std::string& cache_var,
const size_t& idx) const {
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
ir::Node* op = view_.AllOps()[i];
PADDLE_ENFORCE(op->IsOp() && op->Op());
auto* op_desc = op->Op();
op_desc->RenameInput(var, cache_var);
op_desc->RenameOutput(var, cache_var);
op_desc->Flush();
auto node_name = node->Name();
if (node_name == "send" || node_name == "recv" || node_name == "prefetch") {
update_skip_set(node);
}
}
}
const NodeSwapQueue InplacePass::TryInplaceModifyVar(
const std::string& var, const std::string& cache_var, const size_t& idx,
ir::Graph* graph) const {
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
var_nodes_[var].at(0)->Var() != nullptr);
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
var_desc->SetName(cache_var);
NodeSwapQueue swap_nodes;
void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
ir::Node *out_var) const {
auto out_var_name = out_var->Name();
auto in_var_name = in_var->Name();
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
auto* op = view_.AllOps()[i];
auto &all_out_nodes = *AllVersionVars(out_var_name);
auto &all_in_nodes = *AllVersionVars(in_var_name);
// redirect the input to the latest version of cache_var
for (auto* node : op->inputs) {
if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
auto iter = std::find(all_out_nodes.begin(), all_out_nodes.end(), out_var);
PADDLE_ENFORCE(iter != all_out_nodes.end(), "Cannot find out var %s",
out_var_name);
// swap node to cache_node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
cache_node);
cache_node->inputs.emplace_back(prev_op);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
swap_nodes.emplace_back(std::make_pair(node, cache_node));
// The following codes are designed to guarantee that ssa_map_ is still
// an ssa graph after inplace is performed.
// Step 1: Rename the following versions of out_var as the name of in_var
// Step 2: Remove the following versions of out_var and append them to in_var
// Be careful that the inputs of input op of out_var should not be renamed,
// but outputs should be renamed.
auto original_iter = iter;
while (iter != all_out_nodes.end()) {
auto *node = *iter;
/* Step 1 */
node->RenameVar(in_var_name);
if (iter != original_iter) {
for (auto *in : node->inputs) {
if (in->IsOp() && in->Op()) {
in->Op()->RenameOutput(out_var_name, in_var_name);
in->Op()->RenameInput(out_var_name, in_var_name);
in->Op()->Flush();
}
}
// if we need to rename the output,
// always create a newer version of cache_var
for (auto* node : op->outputs) {
if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
// swap node to cache node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
cache_node->inputs.emplace_back(op);
std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
swap_nodes.emplace_back(std::make_pair(node, cache_node));
for (auto *out : node->outputs) {
if (out->IsOp() && out->Op()) {
out->Op()->RenameOutput(out_var_name, in_var_name);
out->Op()->RenameInput(out_var_name, in_var_name);
out->Op()->Flush();
}
}
/* Step 2 */
all_in_nodes.emplace_back(node);
++iter;
}
return swap_nodes;
}
/* Step 2 */
all_out_nodes.erase(original_iter, all_out_nodes.end());
void InplacePass::CommitModify(const NodeSwapQueue& swap_nodes,
ir::Graph* graph) const {
for (auto& pair : swap_nodes) {
auto *node = pair.first, *cache_node = pair.second;
const std::string var = node->Name(), cache_var = cache_node->Name();
var_nodes_[cache_var].emplace_back(cache_node);
graph->RemoveNode(node);
auto& nodes = var_nodes_.at(var);
// release unused var in graph. Because python side memory optimize
// may reused the var in same name, so we only clear the var node
// after current inplaced index.
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
if (all_out_nodes.empty()) {
ssa_map_.erase(out_var_name);
}
op->Op()->RenameOutput(out_var_name, in_var_name);
op->Op()->Flush();
}
void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
ir::Graph* graph) const {
for (auto& pair : nodes) {
auto *node = pair.first, *cache_node = pair.second;
const std::string var = node->Name(), cache_var = cache_node->Name();
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
node);
ir::Node *InplacePass::FindNodeByName(const std::string &name,
const std::vector<ir::Node *> &nodes) {
ir::Node *found_node = nullptr;
for (auto *node : nodes) {
if (node->Name() == name) {
PADDLE_ENFORCE(found_node == nullptr, "Find duplicate input nodes %s",
name);
found_node = node;
}
graph->RemoveNode(cache_node);
}
return found_node;
}
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
ir::Graph* graph) const {
VLOG(4) << "Try to inplace op " << op->Name();
// some pre-requirments need to meet if the op want to inplaced.
PADDLE_ENFORCE(op->Op() != nullptr, "op_desc is nullptr");
auto* op_desc = op->Op();
auto& infer_inplace =
OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_;
void InplacePass::ApplyImpl(ir::Graph *graph) const {
// Step 1: topo sort ops, collect skip vars
auto ops = ir::TopologySortOperations(*graph);
CollectSkipVars(graph, ops);
// 1. infer_inplace_ is registered.
if (!static_cast<bool>(infer_inplace)) return;
PADDLE_ENFORCE(static_cast<bool>(infer_inplace),
"%s's infer_inplace has not been registered", op_desc->Type());
// Step 2: build ssa var map
for (auto *op_node : ops) {
for (auto *in : op_node->inputs) {
PADDLE_ENFORCE(in->IsVar());
// Only create a new var node when var first occurs in input of op.
if (ssa_map_.count(in->Name()) == 0) {
ssa_map_[in->Name()].emplace_back(in);
}
}
auto in_to_outs = infer_inplace(*op_desc);
// Always create a new var node for each output of op.
for (auto *out : op_node->outputs) {
PADDLE_ENFORCE(out->IsVar());
ssa_map_[out->Name()].emplace_back(out);
}
}
auto& all_ops = view_.AllOps();
auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
size_t idx = std::distance(all_ops.begin(), cursor);
// Step 3: traverse ops and try inplace if possible
for (auto *op_node : ops) {
PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr");
for (auto& pair : in_to_outs) {
auto& in_para_name = pair.first;
auto& out_para_name = pair.second;
auto *op_desc = op_node->Op();
auto op_type = op_desc->Type();
auto input_vars = op->Op()->Input(in_para_name);
if (!input_vars.size()) {
VLOG(4) << "Parameter " << in_para_name << " is empty skip "
<< in_para_name << " => " << out_para_name << " pair";
// Skip op inside whitelist
if (whitelist_ops_.count(op_type) > 0) {
continue;
}
auto output_vars = op->Op()->Output(out_para_name);
if (!output_vars.size()) {
VLOG(4) << "Parameter " << out_para_name << " is empty skip "
<< in_para_name << " => " << out_para_name << " pair";
continue;
}
auto in_var_name = input_vars.at(0);
auto out_var_name = output_vars.at(0);
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name;
if (view_.InSkipSet(in_var_name)) {
VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name);
continue;
}
auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
if (view_.InSkipSet(out_var_name)) {
VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name);
if (!infer_inplace) {
continue;
}
if (var_nodes_[in_var_name].back() != in_node) {
VLOG(4) << "SKIP since " << in_var_name
<< " is also used as output by other ops";
auto in_to_outs = infer_inplace(*op_desc);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &out_param = pair.second;
auto &in_args = op_desc->Input(in_param);
auto &out_args = op_desc->Output(out_param);
if (in_args.empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param
<< ") is empty in " << op_type;
continue;
}
bool can_replace = true;
if (in_var_name == out_var_name) {
can_replace = false;
VLOG(4) << "SKIP: Input variable " << in_var_name << " & Output variable "
<< out_var_name << " are the same";
} else if (!NodeCanReused(in_node)) {
can_replace = false;
VLOG(4) << "SKIP: Input variable " << in_var_name << "cannot be reused";
} else if (!NodeCanReused(out_node)) {
can_replace = false;
VLOG(4) << "SKIP: Output variable " << out_var_name
<< " cannot be reused";
} else if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
can_replace = false;
VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType()
<< " does not match Output type : " << out_node->Var()->GetType();
} else if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
can_replace = false;
VLOG(4) << "SKIP: Input and Output varialbe size not match";
if (out_args.empty()) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ") is empty in " << op_type;
continue;
}
if (!can_replace) continue;
auto &in_arg = in_args[0];
auto &out_arg = out_args[0];
// 2. If the variable is the input of muliple ops, we need to make sure
// current op has dependecny on other ops use the same variable
if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input has external dependency."
"inplace such pair will overwrite the memory.",
out_var_name, in_var_name, op->Name());
if (IsSkipVar(in_arg)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is skipped in " << op_type;
continue;
}
// Debug Interface. Which would be skipped by the pass.
if (out_node->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
<< out_node->Name();
if (IsSkipVar(out_arg)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " is skipped in " << op_type;
continue;
}
// NOTE(dzhwinter):
// two stage commit of inplaced process. if after inplace happens generate a
// circle,
// then withdraw the changes. Otherwise, safely add the node.
auto swap_nodes =
TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
if (!ir::HasCircle(*graph)) {
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
out_var_name, in_var_name);
InplaceModifyDesc(out_var_name, in_var_name, idx);
CommitModify(swap_nodes, graph);
} else {
VLOG(3) << string::Sprintf(
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
out_var_name, in_var_name, op->Name());
WithdrawModify(swap_nodes, graph);
}
if (in_arg == out_arg) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is the same with Output(" << out_param << ")=" << out_arg
<< " in " << op_type;
continue;
}
}
void GraphView::TopoSort(ir::Graph* graph) {
//
ops_.clear();
auto deps_num = [](ir::Node* op) {
auto cnt = 0;
for (auto& var : op->inputs)
if (var->inputs.size() > 0) ++cnt;
return cnt;
};
std::queue<std::pair<ir::Node*, uint32_t>> ready_ops;
auto *in_node = FindNodeByName(in_arg, op_node->inputs);
PADDLE_ENFORCE_NOT_NULL(in_node, "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
int level = 0;
auto nodes = graph->Nodes();
std::unordered_map<ir::Node*, uint32_t> deps_map;
for (auto& node : nodes) {
if (node->IsOp() && node->Op() != nullptr) {
deps_map[node] = deps_num(node);
if (0 == deps_map[node]) {
ready_ops.push({node, level});
}
}
if (!NodeCanReused(in_node)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not reusable in " << op_type;
continue;
}
while (!ready_ops.empty()) {
auto item = ready_ops.front();
ready_ops.pop();
ops_.emplace_back(item.first);
// record level when pop from queue
op_level_[item.first] = item.second;
for (auto node : item.first->outputs) {
for (auto op : node->outputs) {
--deps_map[op];
if (deps_map[op] == 0) ready_ops.push({op, item.second + 1});
}
}
if (!IsLastVersionVar(in_node)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the last version in " << op_type;
continue;
}
bool all_ops_checked = true;
for (auto& node : nodes) {
if (node->IsOp() && node->Op() != nullptr && deps_map[node] > 0) {
all_ops_checked = false;
LOG(WARNING)
<< "Node " << node->Name() << " has not been checked. "
<< "Maybe some passes have not handle node dependency rightly.";
break;
}
// If in_node is used as inputs of many ops, check whether all of that ops
// depends on op_node. If not, in_node cannot be inplaced.
if (in_node->outputs.size() > 1 &&
!CheckOpDeps(op_node, in_node->outputs)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not lastly used in " << op_type;
continue;
}
PADDLE_ENFORCE(all_ops_checked, "All ops deps should be 0 after analysis");
}
auto *out_node = FindNodeByName(out_arg, op_node->outputs);
PADDLE_ENFORCE_NOT_NULL(out_node,
"Output(%s)=%s cannot be found in op %s",
out_param, out_arg, op_type);
// return true if current op node depeneds on all other op that use the same
// variable node
bool GraphView::CheckDeps(ir::Node* var, ir::Node* current_op) const {
// get op list that rely on the same variable
auto op_list = var->outputs;
for (auto& op : op_list) {
if (op == current_op) continue;
VLOG(4) << " GraphView::CheckDeps : " << op->Name() << " & "
<< current_op->Name();
if (!CheckOpDeps(op, current_op)) return false;
VLOG(4) << "";
if (!NodeCanReused(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " is not reusable in " << op_type;
continue;
}
return true;
}
// check if op2 depends on op1's output
bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const {
if (VLOG_IS_ON(4)) {
auto print_op = [&](ir::Node* op, const char* name) {
std::ostringstream os;
os << " " << name << " : " << op->Name() << " ";
os << "Input args : ";
for (auto& arg : op->inputs) os << arg->Name() << " ";
os << "Output args : ";
for (auto& arg : op->outputs) os << arg->Name() << " ";
os << "Level : " << op_level_.at(op);
VLOG(4) << os.str();
};
print_op(op1, "OP1");
print_op(op2, "OP2");
if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same type with "
<< "Output(" << out_param << ")=" << out_arg << " in "
<< op_type;
continue;
}
if (op1 == op2) return true;
if (op_level_.at(op1) >= op_level_.at(op2)) return false;
for (auto& var : op2->inputs)
if (var->inputs.size() > 0 && CheckOpDeps(op1, var->inputs[0])) return true;
return false;
}
ir::Node* GraphView::GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const {
// nodes should be op->inputs/outputs
// node in same node do have different name.
std::unordered_set<std::string> nodes_in_op;
bool has_dup_node =
std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) {
if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) {
if (nodes_in_op.count(node->Name())) return true;
nodes_in_op.emplace(node->Name());
if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with "
<< "Output(" << out_param << ")=" << out_arg << " in "
<< op_type;
continue;
}
return false;
});
PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!");
ir::Node* node = nullptr;
for (auto* it : nodes) {
if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue;
if (it->Name() == name) {
node = it;
break;
}
}
PADDLE_ENFORCE(node != nullptr,
string::Sprintf("Not found var %s in nodes!", name));
return node;
}
std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
// get the pending ops depends on same var node.
// because node also maybe a inplaced variable, so need to backtrack all the
// previous inplaced vars.
std::vector<ir::Node*> pending_ops;
ir::Node* p = node;
while (p != nullptr) {
pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end());
p = GetPrevCascadeInplacedVar(p);
}
return pending_ops;
}
void GraphView::Build(ir::Graph* g) {
// track the var nodes in correct order.
// Because we insert some new created node. Which may have data race between
// nodes.
// resolve data harzards depends on the var nodes in right order.
TopoSort(g);
// fill the skip_set_
PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars));
auto& mem_opt_whitelist = g->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
auto update_skip_set = [&](ir::Node* node) {
for (auto& in : node->inputs) {
if (in->IsVar() && in->Var() != nullptr) {
skip_set_.emplace(in->Name());
}
}
for (auto& out : node->outputs) {
if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name());
// Debug Interface. Which would be skipped by the pass.
if (out_arg == FLAGS_memory_optimize_debug) {
VLOG(4) << "Skiped var by force. FLAGS_memory_optimize_debug="
<< out_node->Name();
continue;
}
};
for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue;
// avoid optimize the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) update_skip_set(node);
if (node->Name() == "send") update_skip_set(node);
if (node->Name() == "recv") update_skip_set(node);
if (node->Name() == "prefetch") update_skip_set(node);
VLOG(4) << "Rename " << out_node->Name() << " with " << in_node->Name()
<< " in " << op_type;
RenameInOut(op_node, in_node, out_node);
}
}
}
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
bool GraphView::InSkipSet(const std::string& var) const {
return skip_set_.count(var);
}
} // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may abtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class GraphView {
public:
GraphView() = default;
void Build(ir::Graph* g);
const std::vector<ir::Node*>& AllOps();
ir::Node* GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const;
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
// Will Deperated in the future.
// NOTE(dzhwinter) :
// 1. Python memory optimize will reuse
// memory based var name, so different op output may
// have the same variable name. enable inplace on such node
// will generate a circle in ssa graph.
// 2. DistributeTranspiler will use unique name to
// map the parameter and gradient, must be skipped.
bool InSkipSet(const std::string& var) const;
bool CheckDeps(ir::Node* var, ir::Node* current_op) const;
bool CheckOpDeps(ir::Node* op1, ir::Node* op2) const;
void TopoSort(ir::Graph* g);
private:
std::vector<ir::Node*> ops_;
std::unordered_set<std::string> skip_set_; // mem opt affect nodes
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
std::unordered_map<ir::Node*, uint32_t> op_level_;
};
// swap pairs in sequence
typedef std::vector<std::pair<ir::Node*, ir::Node*>> NodeSwapQueue;
class InplacePass : public ir::Pass {
public:
InplacePass();
protected:
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;
private:
const NodeSwapQueue TryInplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx,
ir::Graph* graph) const;
void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const;
void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const;
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
const size_t& idx) const;
void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
mutable std::unordered_set<std::string> whitelist_;
mutable GraphView view_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -56,7 +56,7 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
q.push(op);
do {
while (!q.empty()) {
op = q.front();
q.pop();
for (auto &pending_op : pending_ops_.at(op)) {
......@@ -65,9 +65,10 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
if (!callback(pending_op)) {
return false;
}
q.push(pending_op);
}
}
}
} while (!q.empty());
return true;
}
......
......@@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor {
const OpGraphView graph_;
};
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
} while (!q.empty());
return nullptr;
}
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
const ShrinkDepsOpFunctor &shrink_func,
bool *ok) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
// who generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}
// No pending ops or generated op is nullptr
if (candidates.empty()) {
*ok = false;
return {};
}
}
// stage two. Try to cast them to computation op.
// return (*ok=false) when failed.
//
// The reason why we cannot make any types of op handle to be the last lived
// op is:
// some op handle may operate on many DeviceContext, however, our garbage
// collector can only wait one DeviceContext for now. So currently, we wait
// the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
*ok = false;
return {};
}
computation_op.emplace(compute_op);
}
}
// stage three. Try to shrink computation op if they depend on each other.
// Get the smallest set of the most ops.
*ok = true;
return shrink_func(computation_op);
}
/**
* Shrink op dependencies according to no need buffer vars.
*
......@@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
}
return nullptr;
}
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
// who generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}
// No pending ops or generated op is nullptr
if (candidates.empty()) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
}
// stage two. Try to cast them to computation op.
// return (*status=kFailure) when failed.
//
// The reason why we cannot make any types of op handle to be the last lived
// op is:
// some op handle may operate on many DeviceContext, however, our garbage
// collector can only wait one DeviceContext for now. So currently, we wait
// the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
computation_op.emplace(compute_op);
}
}
// stage three. Try to shrink computation op if any of them does
// not need the buffer of var_name.
// If all computation ops do not need the buffer of var_name,
// return empty computation op set, and mark the status as kShouldPrecede,
// which means that the last living ops of var_name should be
// found in the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) {
*status = LastLiveOpSearchStatus::kShouldPrecede;
return {};
}
PADDLE_ENFORCE(!computation_op.empty(),
"Computation ops should not be empty");
// stage four. Try to shrink computation op if they depend on each other.
// Get the smallest set of the most ops.
*status = LastLiveOpSearchStatus::kSuccess;
return shrink_func(computation_op);
}
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
......@@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
// Whether this variable can be reused or deleted? If not, we do not
// compute reference counts and dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
if (var_desc == nullptr || var_desc->Persistable()) {
continue;
}
......@@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second;
PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);
for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) {
bool ok;
auto result =
ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok);
VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
*iter, i, var_name, shrink_func, &status);
// Seldomly, some vars may have no pending or preceding computation ops
// Just break;
if (!ok) break;
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
if (status == LastLiveOpSearchStatus::kFailure) {
break;
}
size_t original_op_deps = result.size();
// If all ops do not need buffer of var_name, calculate reference count
// of the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
if (status == LastLiveOpSearchStatus::kShouldPrecede) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
}
size_t final_op_deps = result.size();
if (final_op_deps < original_op_deps) {
VLOG(5) << "Shrink op deps from " << original_op_deps << " to "
<< final_op_deps;
}
PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess);
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break;
......
......@@ -18,7 +18,6 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -27,9 +26,15 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
USE_PASS(inplace_pass);
namespace paddle {
namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
return ir::PassRegistry::Instance().Get("inplace_pass");
}
class NOP : public OperatorBase {
public:
NOP(const std::string& type, const VariableNameMap& inputs,
......@@ -202,7 +207,7 @@ ir::Node* GetNodeFromGraph(ir::Graph* g, std::string name) {
std::unique_ptr<ir::Graph> test_SingleOpInplaceInToOut(
std::unique_ptr<ir::Graph> g) {
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
auto pass = CreateInplacePass();
ir::Node* op_node = GetNodeFromGraph(g.get(), "single_op");
EXPECT_NE(op_node, nullptr);
pass->Apply(g.get());
......@@ -268,7 +273,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
ASSERT_TRUE(op_node != nullptr);
......@@ -304,7 +309,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
ASSERT_TRUE(op_node != nullptr);
......
......@@ -108,11 +108,18 @@ class Node {
Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
void RenameVar(const std::string& new_name) {
PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_,
"Must be type of variable");
name_ = new_name;
var_desc_->SetName(new_name);
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected:
const std::string name_;
std::string name_;
std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_;
Type type_;
......
......@@ -220,16 +220,6 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
}
};
class SoftmaxInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"},
};
}
};
} // namespace operators
} // namespace paddle
......
......@@ -74,3 +74,7 @@ class TestIrInplace(TestParallelExecutorBase):
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册