未验证 提交 4e1bc6e8 编写于 作者: Z Zeng Jinle 提交者: GitHub

Rewrite inplace pass and fix gc bug (#17126)

* fix op graph view
test=develop

* rewrite inplace pass and fix reference count pass bug
test=develop

* fix unittest failed
test=develop

* follow comments, test=develop
上级 08773b60
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,20 +12,14 @@ ...@@ -12,20 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/inplace_op_pass.h" #include <map>
#include <algorithm>
#include <deque>
#include <iterator>
#include <memory>
#include <queue> #include <queue>
#include <sstream>
#include <stack>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means one op output variable reuse the input space. // NOTE(dzhwinter): inplace means one op output variable reuse the input space.
...@@ -56,6 +50,10 @@ DEFINE_bool( ...@@ -56,6 +50,10 @@ DEFINE_bool(
DECLARE_string(memory_optimize_debug); DECLARE_string(memory_optimize_debug);
namespace paddle {
namespace framework {
namespace details {
// clang-format off // clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT const std::string kInplacedOpWhiteList[] = { // NOLINT
"sigmoid", "sigmoid",
...@@ -83,490 +81,378 @@ const std::string kInplacedOpWhiteList[] = { // NOLINT ...@@ -83,490 +81,378 @@ const std::string kInplacedOpWhiteList[] = { // NOLINT
// but the static size during compiling time would be wrong. // but the static size during compiling time would be wrong.
// Use a flag to indicate such ops. Please fix me when found a better way. // Use a flag to indicate such ops. Please fix me when found a better way.
static const std::unordered_set<std::string> kSameShapeOpWhiteSet{ // NOLINT static const std::unordered_set<std::string> kSameShapeOpWhiteSet{ // NOLINT
"reshape2" "reshape2", "reshape2_grad"
}; };
// clang-format on // clang-format on
namespace paddle { class InplacePass : public ir::Pass {
namespace framework { public:
namespace details { InplacePass();
static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { protected:
// if next op is inplaced, then return the output var void ApplyImpl(ir::Graph *graph) const override;
// otherwise return nullptr
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); private:
ir::Node* inplaced_var = nullptr; // Collect vars that cannot be reused
for (auto* next_op : var->outputs) { // e.g.: subblock ops in/out, distributed ops in/out, op_role_var
for (auto* output : next_op->outputs) { void CollectSkipVars(ir::Graph *graph,
if (output->IsVar() && !output->IsCtrlVar() && const std::vector<ir::Node *> &ops) const;
output->Name() == var->Name()) {
inplaced_var = output; // Check whether var_name should be skipped
} bool IsSkipVar(const std::string &var_name) const;
// Rename out with name of in, and guarantee that the graph is
// still a SSA graph
void RenameInOut(ir::Node *op, ir::Node *in, ir::Node *out) const;
// Check whether var is the last version one in SSA graph
bool IsLastVersionVar(ir::Node *var) const;
// Check whether all `ops` is the preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
// Find node whose name is equal to the given name
static ir::Node *FindNodeByName(const std::string &name,
const std::vector<ir::Node *> &nodes);
// Get all versions vars named var_name
std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
private:
// SSA graph. var_name -> each version of vars
mutable std::map<std::string, std::vector<ir::Node *>> ssa_map_;
// Skip vars, including subblock ops in/out, distributed ops in/out,
// op_role_var
mutable std::unordered_set<std::string> skip_vars_;
// Op whitelist which should not peform inplace
// Only enabled when FLAGS_enable_inplace_whitelist is true.
mutable std::unordered_set<std::string> whitelist_ops_;
};
InplacePass::InplacePass() {
if (FLAGS_enable_inplace_whitelist) {
for (auto &s : kInplacedOpWhiteList) {
whitelist_ops_.emplace(s);
} }
} }
return inplaced_var;
} }
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) { std::vector<ir::Node *> *InplacePass::AllVersionVars(
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); const std::string &var_name) const {
if (var->inputs.empty()) return nullptr; auto iter = ssa_map_.find(var_name);
auto* prev_op = var->inputs.at(0); PADDLE_ENFORCE(iter != ssa_map_.end(), "cannot find var %s in ssa graph",
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(), var_name);
[&](ir::Node* node) { PADDLE_ENFORCE(!iter->second.empty(), "var %s is empty in ssa graph",
if (node->IsVar() && !node->IsCtrlVar() && var_name);
node->Name() == var->Name()) { return &(iter->second);
return true;
} else {
return false;
}
});
return input_it == prev_op->inputs.end() ? nullptr : *input_it;
} }
InplacePass::InplacePass() : Pass() { bool InplacePass::IsSkipVar(const std::string &var_name) const {
if (FLAGS_enable_inplace_whitelist) { return skip_vars_.count(var_name) > 0;
for (auto& s : kInplacedOpWhiteList) { }
whitelist_.emplace(s);
bool InplacePass::IsLastVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->back() == var;
}
bool InplacePass::CheckOpDeps(ir::Node *op,
const std::vector<ir::Node *> &ops) const {
std::unordered_set<ir::Node *> other_ops(ops.begin(), ops.end());
other_ops.erase(op);
if (other_ops.empty()) return true;
// Traverse all preceding ops of op
std::queue<ir::Node *> queue;
std::unordered_set<ir::Node *> visited_ops;
queue.push(op);
visited_ops.insert(op);
// Visit all preceding ops of `op`, and erase it from other_ops if it is
// inside other_ops. Return true only if other_ops is empty(), which means
// that all `ops` are preceding ops of `op`.
while (!queue.empty()) {
auto *cur_op = queue.front();
queue.pop();
for (auto *in_var : cur_op->inputs) {
for (auto *in_op : in_var->inputs) {
if (visited_ops.count(in_op) != 0) {
continue;
}
visited_ops.insert(in_op);
queue.push(in_op);
other_ops.erase(in_op);
if (other_ops.empty()) return true;
}
} }
} }
return false;
} }
void InplacePass::InitSSAGraphNodes() const { void InplacePass::CollectSkipVars(ir::Graph *graph,
std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars; const std::vector<ir::Node *> &ops) const {
for (auto* op : view_.AllOps()) { // 1. Collect op role vars
for (auto* node : op->inputs) { PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
if (!node->IsVar() || node->IsCtrlVar()) continue; "Graph should have attr %s", details::kMemOptSkipVars);
if (all_vars[node->Name()].count(node) == 0) { auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
all_vars[node->Name()].emplace(node); for (const auto &var : mem_opt_whitelist) {
var_nodes_[node->Name()].emplace_back(node); skip_vars_.emplace(var);
}
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other's name.
// Also check the ops which has sub-block
auto update_skip_set = [&](ir::Node *node) {
for (auto &in : node->inputs) {
if (in->IsVar() && in->Var() != nullptr) {
skip_vars_.emplace(in->Name());
} }
} }
for (auto* node : op->outputs) { for (auto &out : node->outputs) {
if (!node->IsVar() || node->IsCtrlVar()) continue; if (out->IsVar() && out->Var() != nullptr) {
if (all_vars[node->Name()].count(node) == 0) { skip_vars_.emplace(out->Name());
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
} }
} }
} };
}
void InplacePass::ApplyImpl(ir::Graph* graph) const {
var_nodes_.clear();
view_.Build(graph);
InitSSAGraphNodes();
auto cnt = 0; for (auto *node : ops) {
for (auto* op : view_.AllOps()) { if (!node->IsOp()) continue;
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name(); // avoid optimizing the variable used in sub-blocks
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) if (OpHasSubBlock(node->Op())) {
update_skip_set(node);
continue; continue;
TryInplaceOpInputOutput(op, graph); }
}
}
void InplacePass::InplaceModifyDesc(const std::string& var, auto node_name = node->Name();
const std::string& cache_var, if (node_name == "send" || node_name == "recv" || node_name == "prefetch") {
const size_t& idx) const { update_skip_set(node);
for (size_t i = idx; i < view_.AllOps().size(); ++i) { }
ir::Node* op = view_.AllOps()[i];
PADDLE_ENFORCE(op->IsOp() && op->Op());
auto* op_desc = op->Op();
op_desc->RenameInput(var, cache_var);
op_desc->RenameOutput(var, cache_var);
op_desc->Flush();
} }
} }
const NodeSwapQueue InplacePass::TryInplaceModifyVar( void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
const std::string& var, const std::string& cache_var, const size_t& idx, ir::Node *out_var) const {
ir::Graph* graph) const { auto out_var_name = out_var->Name();
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && auto in_var_name = in_var->Name();
var_nodes_[var].at(0)->Var() != nullptr);
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); auto &all_out_nodes = *AllVersionVars(out_var_name);
var_desc->SetName(cache_var); auto &all_in_nodes = *AllVersionVars(in_var_name);
NodeSwapQueue swap_nodes; auto iter = std::find(all_out_nodes.begin(), all_out_nodes.end(), out_var);
PADDLE_ENFORCE(iter != all_out_nodes.end(), "Cannot find out var %s",
for (size_t i = idx; i < view_.AllOps().size(); ++i) { out_var_name);
auto* op = view_.AllOps()[i];
// The following codes are designed to guarantee that ssa_map_ is still
// redirect the input to the latest version of cache_var // an ssa graph after inplace is performed.
for (auto* node : op->inputs) { // Step 1: Rename the following versions of out_var as the name of in_var
if (node->Name() == var) { // Step 2: Remove the following versions of out_var and append them to in_var
ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); // Be careful that the inputs of input op of out_var should not be renamed,
// but outputs should be renamed.
// swap node to cache_node auto original_iter = iter;
cache_node->outputs.insert(cache_node->outputs.end(), while (iter != all_out_nodes.end()) {
node->outputs.begin(), node->outputs.end()); auto *node = *iter;
PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp()); /* Step 1 */
auto* prev_op = node->inputs[0]; node->RenameVar(in_var_name);
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, if (iter != original_iter) {
cache_node); for (auto *in : node->inputs) {
cache_node->inputs.emplace_back(prev_op); if (in->IsOp() && in->Op()) {
for (auto* next_op : node->outputs) { in->Op()->RenameOutput(out_var_name, in_var_name);
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, in->Op()->RenameInput(out_var_name, in_var_name);
cache_node); in->Op()->Flush();
} }
swap_nodes.emplace_back(std::make_pair(node, cache_node));
} }
} }
// if we need to rename the output, for (auto *out : node->outputs) {
// always create a newer version of cache_var if (out->IsOp() && out->Op()) {
for (auto* node : op->outputs) { out->Op()->RenameOutput(out_var_name, in_var_name);
if (node->Name() == var) { out->Op()->RenameInput(out_var_name, in_var_name);
ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); out->Op()->Flush();
// swap node to cache node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
cache_node->inputs.emplace_back(op);
std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
swap_nodes.emplace_back(std::make_pair(node, cache_node));
} }
} }
/* Step 2 */
all_in_nodes.emplace_back(node);
++iter;
} }
return swap_nodes; /* Step 2 */
} all_out_nodes.erase(original_iter, all_out_nodes.end());
void InplacePass::CommitModify(const NodeSwapQueue& swap_nodes, if (all_out_nodes.empty()) {
ir::Graph* graph) const { ssa_map_.erase(out_var_name);
for (auto& pair : swap_nodes) {
auto *node = pair.first, *cache_node = pair.second;
const std::string var = node->Name(), cache_var = cache_node->Name();
var_nodes_[cache_var].emplace_back(cache_node);
graph->RemoveNode(node);
auto& nodes = var_nodes_.at(var);
// release unused var in graph. Because python side memory optimize
// may reused the var in same name, so we only clear the var node
// after current inplaced index.
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
} }
op->Op()->RenameOutput(out_var_name, in_var_name);
op->Op()->Flush();
} }
void InplacePass::WithdrawModify(const NodeSwapQueue& nodes, ir::Node *InplacePass::FindNodeByName(const std::string &name,
ir::Graph* graph) const { const std::vector<ir::Node *> &nodes) {
for (auto& pair : nodes) { ir::Node *found_node = nullptr;
auto *node = pair.first, *cache_node = pair.second; for (auto *node : nodes) {
const std::string var = node->Name(), cache_var = cache_node->Name(); if (node->Name() == name) {
auto* prev_op = node->inputs[0]; PADDLE_ENFORCE(found_node == nullptr, "Find duplicate input nodes %s",
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node, name);
node); found_node = node;
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
node);
} }
graph->RemoveNode(cache_node);
} }
return found_node;
} }
void InplacePass::TryInplaceOpInputOutput(ir::Node* op, void InplacePass::ApplyImpl(ir::Graph *graph) const {
ir::Graph* graph) const { // Step 1: topo sort ops, collect skip vars
VLOG(4) << "Try to inplace op " << op->Name(); auto ops = ir::TopologySortOperations(*graph);
// some pre-requirments need to meet if the op want to inplaced. CollectSkipVars(graph, ops);
PADDLE_ENFORCE(op->Op() != nullptr, "op_desc is nullptr");
// Step 2: build ssa var map
auto* op_desc = op->Op(); for (auto *op_node : ops) {
auto& infer_inplace = for (auto *in : op_node->inputs) {
OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_; PADDLE_ENFORCE(in->IsVar());
// Only create a new var node when var first occurs in input of op.
// 1. infer_inplace_ is registered. if (ssa_map_.count(in->Name()) == 0) {
if (!static_cast<bool>(infer_inplace)) return; ssa_map_[in->Name()].emplace_back(in);
PADDLE_ENFORCE(static_cast<bool>(infer_inplace), }
"%s's infer_inplace has not been registered", op_desc->Type());
auto in_to_outs = infer_inplace(*op_desc);
auto& all_ops = view_.AllOps();
auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
size_t idx = std::distance(all_ops.begin(), cursor);
for (auto& pair : in_to_outs) {
auto& in_para_name = pair.first;
auto& out_para_name = pair.second;
auto input_vars = op->Op()->Input(in_para_name);
if (!input_vars.size()) {
VLOG(4) << "Parameter " << in_para_name << " is empty skip "
<< in_para_name << " => " << out_para_name << " pair";
continue;
}
auto output_vars = op->Op()->Output(out_para_name);
if (!output_vars.size()) {
VLOG(4) << "Parameter " << out_para_name << " is empty skip "
<< in_para_name << " => " << out_para_name << " pair";
continue;
}
auto in_var_name = input_vars.at(0);
auto out_var_name = output_vars.at(0);
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name;
if (view_.InSkipSet(in_var_name)) {
VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name);
continue;
}
if (view_.InSkipSet(out_var_name)) {
VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name);
continue;
} }
if (var_nodes_[in_var_name].back() != in_node) { // Always create a new var node for each output of op.
VLOG(4) << "SKIP since " << in_var_name for (auto *out : op_node->outputs) {
<< " is also used as output by other ops"; PADDLE_ENFORCE(out->IsVar());
continue; ssa_map_[out->Name()].emplace_back(out);
} }
}
bool can_replace = true; // Step 3: traverse ops and try inplace if possible
if (in_var_name == out_var_name) { for (auto *op_node : ops) {
can_replace = false; PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr");
VLOG(4) << "SKIP: Input variable " << in_var_name << " & Output variable "
<< out_var_name << " are the same";
} else if (!NodeCanReused(in_node)) {
can_replace = false;
VLOG(4) << "SKIP: Input variable " << in_var_name << "cannot be reused";
} else if (!NodeCanReused(out_node)) {
can_replace = false;
VLOG(4) << "SKIP: Output variable " << out_var_name
<< " cannot be reused";
} else if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
can_replace = false;
VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType()
<< " does not match Output type : " << out_node->Var()->GetType();
} else if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
can_replace = false;
VLOG(4) << "SKIP: Input and Output varialbe size not match";
}
if (!can_replace) continue; auto *op_desc = op_node->Op();
auto op_type = op_desc->Type();
// 2. If the variable is the input of muliple ops, we need to make sure // Skip op inside whitelist
// current op has dependecny on other ops use the same variable if (whitelist_ops_.count(op_type) > 0) {
if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input has external dependency."
"inplace such pair will overwrite the memory.",
out_var_name, in_var_name, op->Name());
continue; continue;
} }
// Debug Interface. Which would be skipped by the pass. auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
if (out_node->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" if (!infer_inplace) {
<< out_node->Name();
continue; continue;
} }
// NOTE(dzhwinter): auto in_to_outs = infer_inplace(*op_desc);
// two stage commit of inplaced process. if after inplace happens generate a for (auto &pair : in_to_outs) {
// circle, auto &in_param = pair.first;
// then withdraw the changes. Otherwise, safely add the node. auto &out_param = pair.second;
auto swap_nodes =
TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
if (!ir::HasCircle(*graph)) {
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
out_var_name, in_var_name);
InplaceModifyDesc(out_var_name, in_var_name, idx);
CommitModify(swap_nodes, graph);
} else {
VLOG(3) << string::Sprintf(
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
out_var_name, in_var_name, op->Name());
WithdrawModify(swap_nodes, graph);
}
}
}
void GraphView::TopoSort(ir::Graph* graph) { auto &in_args = op_desc->Input(in_param);
// auto &out_args = op_desc->Output(out_param);
ops_.clear();
auto deps_num = [](ir::Node* op) {
auto cnt = 0;
for (auto& var : op->inputs)
if (var->inputs.size() > 0) ++cnt;
return cnt;
};
std::queue<std::pair<ir::Node*, uint32_t>> ready_ops; if (in_args.empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param
<< ") is empty in " << op_type;
continue;
}
int level = 0; if (out_args.empty()) {
auto nodes = graph->Nodes(); VLOG(4) << "Cannot inplace because Output(" << out_param
std::unordered_map<ir::Node*, uint32_t> deps_map; << ") is empty in " << op_type;
for (auto& node : nodes) { continue;
if (node->IsOp() && node->Op() != nullptr) {
deps_map[node] = deps_num(node);
if (0 == deps_map[node]) {
ready_ops.push({node, level});
} }
}
}
while (!ready_ops.empty()) { auto &in_arg = in_args[0];
auto item = ready_ops.front(); auto &out_arg = out_args[0];
ready_ops.pop();
ops_.emplace_back(item.first); if (IsSkipVar(in_arg)) {
// record level when pop from queue VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
op_level_[item.first] = item.second; << " is skipped in " << op_type;
continue;
}
for (auto node : item.first->outputs) { if (IsSkipVar(out_arg)) {
for (auto op : node->outputs) { VLOG(4) << "Cannot inplace because Output(" << out_param
--deps_map[op]; << ")=" << out_arg << " is skipped in " << op_type;
if (deps_map[op] == 0) ready_ops.push({op, item.second + 1}); continue;
} }
}
}
bool all_ops_checked = true; if (in_arg == out_arg) {
for (auto& node : nodes) { VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
if (node->IsOp() && node->Op() != nullptr && deps_map[node] > 0) { << " is the same with Output(" << out_param << ")=" << out_arg
all_ops_checked = false; << " in " << op_type;
LOG(WARNING) continue;
<< "Node " << node->Name() << " has not been checked. " }
<< "Maybe some passes have not handle node dependency rightly.";
break;
}
}
PADDLE_ENFORCE(all_ops_checked, "All ops deps should be 0 after analysis"); auto *in_node = FindNodeByName(in_arg, op_node->inputs);
} PADDLE_ENFORCE_NOT_NULL(in_node, "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
// return true if current op node depeneds on all other op that use the same if (!NodeCanReused(in_node)) {
// variable node VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
bool GraphView::CheckDeps(ir::Node* var, ir::Node* current_op) const { << " is not reusable in " << op_type;
// get op list that rely on the same variable continue;
auto op_list = var->outputs; }
for (auto& op : op_list) {
if (op == current_op) continue;
VLOG(4) << " GraphView::CheckDeps : " << op->Name() << " & "
<< current_op->Name();
if (!CheckOpDeps(op, current_op)) return false;
VLOG(4) << "";
}
return true;
}
// check if op2 depends on op1's output
bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const {
if (VLOG_IS_ON(4)) {
auto print_op = [&](ir::Node* op, const char* name) {
std::ostringstream os;
os << " " << name << " : " << op->Name() << " ";
os << "Input args : ";
for (auto& arg : op->inputs) os << arg->Name() << " ";
os << "Output args : ";
for (auto& arg : op->outputs) os << arg->Name() << " ";
os << "Level : " << op_level_.at(op);
VLOG(4) << os.str();
};
print_op(op1, "OP1");
print_op(op2, "OP2");
}
if (op1 == op2) return true;
if (op_level_.at(op1) >= op_level_.at(op2)) return false;
for (auto& var : op2->inputs) if (!IsLastVersionVar(in_node)) {
if (var->inputs.size() > 0 && CheckOpDeps(op1, var->inputs[0])) return true; VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the last version in " << op_type;
continue;
}
return false; // If in_node is used as inputs of many ops, check whether all of that ops
} // depends on op_node. If not, in_node cannot be inplaced.
if (in_node->outputs.size() > 1 &&
!CheckOpDeps(op_node, in_node->outputs)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not lastly used in " << op_type;
continue;
}
ir::Node* GraphView::GetNodeByName(const std::string& name, auto *out_node = FindNodeByName(out_arg, op_node->outputs);
const std::vector<ir::Node*>& nodes) const { PADDLE_ENFORCE_NOT_NULL(out_node,
// nodes should be op->inputs/outputs "Output(%s)=%s cannot be found in op %s",
// node in same node do have different name. out_param, out_arg, op_type);
std::unordered_set<std::string> nodes_in_op;
bool has_dup_node =
std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) {
if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) {
if (nodes_in_op.count(node->Name())) return true;
nodes_in_op.emplace(node->Name());
}
return false;
});
PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!");
ir::Node* node = nullptr;
for (auto* it : nodes) {
if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue;
if (it->Name() == name) {
node = it;
break;
}
}
PADDLE_ENFORCE(node != nullptr,
string::Sprintf("Not found var %s in nodes!", name));
return node;
}
std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) { if (!NodeCanReused(out_node)) {
// get the pending ops depends on same var node. VLOG(4) << "Cannot inplace because Output(" << out_param
// because node also maybe a inplaced variable, so need to backtrack all the << ")=" << out_arg << " is not reusable in " << op_type;
// previous inplaced vars. continue;
std::vector<ir::Node*> pending_ops; }
ir::Node* p = node;
while (p != nullptr) {
pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end());
p = GetPrevCascadeInplacedVar(p);
}
return pending_ops;
}
void GraphView::Build(ir::Graph* g) { if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
// track the var nodes in correct order. VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
// Because we insert some new created node. Which may have data race between << " is not the same type with "
// nodes. << "Output(" << out_param << ")=" << out_arg << " in "
// resolve data harzards depends on the var nodes in right order. << op_type;
TopoSort(g); continue;
}
// fill the skip_set_ if (details::NodeSize(*in_node->Var()) !=
PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars)); details::NodeSize(*out_node->Var()) &&
auto& mem_opt_whitelist = g->Get<MemOptSkipVars>(kMemOptSkipVars); kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var); VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with "
<< "Output(" << out_param << ")=" << out_arg << " in "
<< op_type;
continue;
}
// 2. track the nodes which used by parameter server. // Debug Interface. Which would be skipped by the pass.
// these node can not be inplaced, otherwise trainer if (out_arg == FLAGS_memory_optimize_debug) {
// pserver can not find each other name. VLOG(4) << "Skiped var by force. FLAGS_memory_optimize_debug="
auto update_skip_set = [&](ir::Node* node) { << out_node->Name();
for (auto& in : node->inputs) { continue;
if (in->IsVar() && in->Var() != nullptr) {
skip_set_.emplace(in->Name());
} }
}
for (auto& out : node->outputs) {
if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name());
}
};
for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue;
// avoid optimize the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) update_skip_set(node);
if (node->Name() == "send") update_skip_set(node); VLOG(4) << "Rename " << out_node->Name() << " with " << in_node->Name()
if (node->Name() == "recv") update_skip_set(node); << " in " << op_type;
if (node->Name() == "prefetch") update_skip_set(node); RenameInOut(op_node, in_node, out_node);
}
} }
} }
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
bool GraphView::InSkipSet(const std::string& var) const {
return skip_set_.count(var);
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may abtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class GraphView {
public:
GraphView() = default;
void Build(ir::Graph* g);
const std::vector<ir::Node*>& AllOps();
ir::Node* GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const;
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
// Will Deperated in the future.
// NOTE(dzhwinter) :
// 1. Python memory optimize will reuse
// memory based var name, so different op output may
// have the same variable name. enable inplace on such node
// will generate a circle in ssa graph.
// 2. DistributeTranspiler will use unique name to
// map the parameter and gradient, must be skipped.
bool InSkipSet(const std::string& var) const;
bool CheckDeps(ir::Node* var, ir::Node* current_op) const;
bool CheckOpDeps(ir::Node* op1, ir::Node* op2) const;
void TopoSort(ir::Graph* g);
private:
std::vector<ir::Node*> ops_;
std::unordered_set<std::string> skip_set_; // mem opt affect nodes
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
std::unordered_map<ir::Node*, uint32_t> op_level_;
};
// swap pairs in sequence
typedef std::vector<std::pair<ir::Node*, ir::Node*>> NodeSwapQueue;
class InplacePass : public ir::Pass {
public:
InplacePass();
protected:
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;
private:
const NodeSwapQueue TryInplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx,
ir::Graph* graph) const;
void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const;
void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const;
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
const size_t& idx) const;
void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
mutable std::unordered_set<std::string> whitelist_;
mutable GraphView view_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -56,7 +56,7 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, ...@@ -56,7 +56,7 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
std::unordered_set<OpHandleBase *> visited; std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q; std::queue<OpHandleBase *> q;
q.push(op); q.push(op);
do { while (!q.empty()) {
op = q.front(); op = q.front();
q.pop(); q.pop();
for (auto &pending_op : pending_ops_.at(op)) { for (auto &pending_op : pending_ops_.at(op)) {
...@@ -65,9 +65,10 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, ...@@ -65,9 +65,10 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
if (!callback(pending_op)) { if (!callback(pending_op)) {
return false; return false;
} }
q.push(pending_op);
} }
} }
} while (!q.empty()); }
return true; return true;
} }
......
...@@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor { ...@@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor {
const OpGraphView graph_; const OpGraphView graph_;
}; };
/**
 * Find the nearest downstream computation op handle. If the op is a
 * computation op, just return itself.
 *
 * Performs a BFS starting at `op`, following output variables to their
 * pending ops, and returns the first ComputationOpHandle found whose
 * scope index equals `scope_idx`. Returns nullptr when no such op is
 * reachable downstream.
 */
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
    OpHandleBase *op, size_t scope_idx) {
  std::queue<OpHandleBase *> q;
  std::unordered_set<OpHandleBase *> visited;
  q.push(op);
  do {
    // NOTE: this inner `op` shadows the parameter; the parameter is only
    // used to seed the queue above.
    auto *op = q.front();
    q.pop();
    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
    if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
      return compute_op;
    }
    // Enqueue every not-yet-visited op that consumes one of this op's
    // outputs; `visited` guarantees each op is expanded at most once.
    for (auto *out_var : op->Outputs()) {
      for (auto *pending_op : out_var->PendingOps()) {
        if (visited.count(pending_op)) continue;
        visited.insert(pending_op);
        q.push(pending_op);
      }
    }
  } while (!q.empty());
  // No computation op with the requested scope index is reachable.
  return nullptr;
}
/**
 * Collect the computation ops after which `var` is no longer alive.
 *
 * @param var         the variable handle whose last living ops are wanted.
 * @param scope_idx   scope index the resulting computation ops must run in.
 * @param shrink_func functor that removes ops already covered by the
 *                    dependencies of other ops in the set.
 * @param ok          out-param; set to false when no suitable computation
 *                    op can be found, true otherwise.
 * @return the shrunk set of last living computation ops, or an empty set
 *         when *ok is false.
 */
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
                                     const ShrinkDepsOpFunctor &shrink_func,
                                     bool *ok) {
  // stage one. Get last op for variable.
  std::unordered_set<OpHandleBase *> candidates;
  {
    if (var->PendingOps().empty() && var->GeneratedOp()) {
      // No operator depends on this variable. So the last operator is the op
      // who generates this variable.
      candidates.emplace(var->GeneratedOp());
    } else {
      candidates = var->PendingOps();
    }

    // No pending ops or generated op is nullptr
    if (candidates.empty()) {
      *ok = false;
      return {};
    }
  }

  // stage two. Try to cast them to computation op.
  // return (*ok=false) when failed.
  //
  // The reason why we cannot make any types of op handle to be the last lived
  // op is:
  //    some op handle may operate on many DeviceContext, however, our garbage
  //    collector can only wait one DeviceContext for now. So currently, we wait
  //    the nearest compute op.
  std::unordered_set<ComputationOpHandle *> computation_op;
  {
    for (auto *op : candidates) {
      auto *compute_op =
          FindNextComputationOpHandleOrReturnItself(op, scope_idx);
      // Any candidate without a reachable computation op fails the search.
      if (compute_op == nullptr) {
        *ok = false;
        return {};
      }
      computation_op.emplace(compute_op);
    }
  }

  // stage three. Try to shrink computation op if they depend on each other.
  // Get the smallest set of the most ops.
  *ok = true;
  return shrink_func(computation_op);
}
/** /**
* Shrink op dependencies according to no need buffer vars. * Shrink op dependencies according to no need buffer vars.
* *
...@@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency( ...@@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency(
} }
} }
/**
 * Find the nearest downstream computation op handle. If the op is a
 * computation op, just return itself.
 *
 * Breadth-first search from `op` through output variables to their pending
 * ops; the first ComputationOpHandle whose scope index matches `scope_idx`
 * is returned. Returns nullptr when none is reachable.
 */
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
    OpHandleBase *op, size_t scope_idx) {
  std::unordered_set<OpHandleBase *> seen;
  std::queue<OpHandleBase *> pending;
  pending.push(op);
  while (!pending.empty()) {
    OpHandleBase *cur = pending.front();
    pending.pop();
    // A computation op running in the requested scope terminates the search.
    if (auto *compute_op = dynamic_cast<ComputationOpHandle *>(cur)) {
      if (compute_op->GetScopeIdx() == scope_idx) {
        return compute_op;
      }
    }
    // Expand each successor at most once: insert().second is true only the
    // first time an op is encountered.
    for (auto *out_var : cur->Outputs()) {
      for (auto *succ : out_var->PendingOps()) {
        if (seen.insert(succ).second) {
          pending.push(succ);
        }
      }
    }
  }
  // No computation op with scope_idx is reachable downstream.
  return nullptr;
}
// Outcome of searching for the last living ops of a variable:
//  - kSuccess:       the last living computation ops were found.
//  - kFailure:       no suitable computation op exists (e.g. the variable has
//                    neither pending ops nor a generated op).
//  - kShouldPrecede: every candidate op declares the variable as a
//                    no-need-buffer input, so the search should continue with
//                    the previous version of the variable.
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
/**
 * Collect the computation ops after which `var` is no longer alive.
 *
 * @param var         the variable handle whose last living ops are wanted.
 * @param scope_idx   scope index the resulting computation ops must run in.
 * @param var_name    name of the variable, used for the no-need-buffer check.
 * @param shrink_func functor that removes ops already covered by the
 *                    dependencies of other ops in the set.
 * @param status      out-param describing the search outcome; the returned
 *                    set is non-empty only when *status == kSuccess.
 */
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
                                     const std::string &var_name,
                                     const ShrinkDepsOpFunctor &shrink_func,
                                     LastLiveOpSearchStatus *status) {
  // Stage 1: the ops that touch the variable last. When nothing consumes the
  // variable, its producer (if any) is the last op.
  std::unordered_set<OpHandleBase *> last_ops =
      (var->PendingOps().empty() && var->GeneratedOp())
          ? std::unordered_set<OpHandleBase *>{var->GeneratedOp()}
          : var->PendingOps();
  if (last_ops.empty()) {
    // Neither pending ops nor a generated op: nothing to wait on.
    *status = LastLiveOpSearchStatus::kFailure;
    return {};
  }

  // Stage 2: map each candidate to its nearest downstream computation op in
  // this scope. Only computation ops are acceptable because some op handles
  // operate on many DeviceContexts while the garbage collector can wait on
  // just one, so we wait on the nearest compute op instead.
  std::unordered_set<ComputationOpHandle *> compute_ops;
  for (auto *candidate : last_ops) {
    auto *compute_op =
        FindNextComputationOpHandleOrReturnItself(candidate, scope_idx);
    if (compute_op == nullptr) {
      *status = LastLiveOpSearchStatus::kFailure;
      return {};
    }
    compute_ops.emplace(compute_op);
  }

  // Stage 3: drop ops that do not need the buffer of var_name. If all of
  // them are dropped, report kShouldPrecede so the caller retries with the
  // previous version of var_name.
  if (ShrinkNoNeedBufferVarOpDependency(var_name, &compute_ops)) {
    *status = LastLiveOpSearchStatus::kShouldPrecede;
    return {};
  }

  PADDLE_ENFORCE(!compute_ops.empty(),
                 "Computation ops should not be empty");

  // Stage 4: shrink ops that depend on each other; waiting on the later ops
  // in the dependency chain is sufficient.
  *status = LastLiveOpSearchStatus::kSuccess;
  return shrink_func(compute_ops);
}
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount); auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars = auto &last_live_ops_of_vars =
...@@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
ShrinkDepsOpFunctor shrink_func( ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph)); ir::FilterByNodeWrapper<OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) { for (auto &name_var_pair : vars[i]) {
// Whether this variable can be reused or deleted? If not, we do not // Whether this variable can be reused or deleted? If not, we do not
// compute reference counts and dependencies. // compute reference counts and dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second); VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
if (var_desc == nullptr || var_desc->Persistable()) { if (var_desc == nullptr || var_desc->Persistable()) {
continue; continue;
} }
...@@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &var_name = name_var_pair.first; auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second; auto &var_handles = name_var_pair.second;
PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);
for (auto iter = var_handles.rbegin(); iter != var_handles.rend(); for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) { ++iter) {
bool ok; VLOG(10) << "Try to find last living ops of " << var_name << " "
auto result = << (iter - var_handles.rbegin()) << " time";
ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok); LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
*iter, i, var_name, shrink_func, &status);
// Seldomly, some vars may have no pending or preceding computation ops // Seldomly, some vars may have no pending or preceding computation ops
// Just break; // Just break;
if (!ok) break; if (status == LastLiveOpSearchStatus::kFailure) {
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; break;
}
size_t original_op_deps = result.size(); if (status == LastLiveOpSearchStatus::kShouldPrecede) {
// If all ops do not need buffer of var_name, calculate reference count
// of the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
VLOG(10) << "Try to precede reference count computing at var " VLOG(10) << "Try to precede reference count computing at var "
<< var_name; << var_name;
continue; continue;
} }
size_t final_op_deps = result.size(); PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess);
if (final_op_deps < original_op_deps) {
VLOG(5) << "Shrink op deps from " << original_op_deps << " to "
<< final_op_deps;
}
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty", PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name); var_name);
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
ref_cnts[i].emplace(var_name, result.size()); ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result)); last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break; break;
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -27,9 +26,15 @@ ...@@ -27,9 +26,15 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
USE_PASS(inplace_pass);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Construct a fresh instance of the registered "inplace_pass" so each test
// can apply it to its own graph.
std::unique_ptr<ir::Pass> CreateInplacePass() {
  auto &registry = ir::PassRegistry::Instance();
  return registry.Get("inplace_pass");
}
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
NOP(const std::string& type, const VariableNameMap& inputs, NOP(const std::string& type, const VariableNameMap& inputs,
...@@ -202,7 +207,7 @@ ir::Node* GetNodeFromGraph(ir::Graph* g, std::string name) { ...@@ -202,7 +207,7 @@ ir::Node* GetNodeFromGraph(ir::Graph* g, std::string name) {
std::unique_ptr<ir::Graph> test_SingleOpInplaceInToOut( std::unique_ptr<ir::Graph> test_SingleOpInplaceInToOut(
std::unique_ptr<ir::Graph> g) { std::unique_ptr<ir::Graph> g) {
std::unique_ptr<details::InplacePass> pass(new details::InplacePass()); auto pass = CreateInplacePass();
ir::Node* op_node = GetNodeFromGraph(g.get(), "single_op"); ir::Node* op_node = GetNodeFromGraph(g.get(), "single_op");
EXPECT_NE(op_node, nullptr); EXPECT_NE(op_node, nullptr);
pass->Apply(g.get()); pass->Apply(g.get());
...@@ -268,7 +273,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) { ...@@ -268,7 +273,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>()); g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass()); auto pass = CreateInplacePass();
pass->Apply(g.get()); pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op"); auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
ASSERT_TRUE(op_node != nullptr); ASSERT_TRUE(op_node != nullptr);
...@@ -304,7 +309,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) { ...@@ -304,7 +309,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>()); g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass()); auto pass = CreateInplacePass();
pass->Apply(g.get()); pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad"); auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
ASSERT_TRUE(op_node != nullptr); ASSERT_TRUE(op_node != nullptr);
......
...@@ -108,11 +108,18 @@ class Node { ...@@ -108,11 +108,18 @@ class Node {
Name().find(ir::Node::kControlDepVarName) != std::string::npos; Name().find(ir::Node::kControlDepVarName) != std::string::npos;
} }
// Rename this node; only variable nodes backed by a VarDesc may be renamed,
// and the node name is kept in sync with the underlying desc.
void RenameVar(const std::string& new_name) {
  PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_,
                 "Must be type of variable");
  var_desc_->SetName(new_name);
  name_ = new_name;
}
std::vector<Node*> inputs; std::vector<Node*> inputs;
std::vector<Node*> outputs; std::vector<Node*> outputs;
protected: protected:
const std::string name_; std::string name_;
std::unique_ptr<VarDesc> var_desc_; std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_; std::unique_ptr<OpDesc> op_desc_;
Type type_; Type type_;
......
...@@ -220,16 +220,6 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -220,16 +220,6 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
class SoftmaxInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"},
};
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -74,3 +74,7 @@ class TestIrInplace(TestParallelExecutorBase): ...@@ -74,3 +74,7 @@ class TestIrInplace(TestParallelExecutorBase):
self.assertAlmostEqual(loss00, loss10, delta=delta) self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta) self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta) self.assertAlmostEqual(loss00, loss11, delta=delta)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册