Unverified commit 4e1bc6e8 authored by Zeng Jinle, committed by GitHub

Rewrite inplace pass and fix gc bug (#17126)

* fix op graph view
test=develop

* rewrite inplace pass and fix reference count pass bug
test=develop

* fix unittest failed
test=develop

* follow comments, test=develop
Parent 08773b60
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class GraphView {
public:
GraphView() = default;
void Build(ir::Graph* g);
const std::vector<ir::Node*>& AllOps();
ir::Node* GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const;
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
// Will be deprecated in the future.
// NOTE(dzhwinter):
// 1. The Python memory optimizer reuses memory based on var names,
// so different op outputs may share the same variable name.
// Enabling inplace on such a node would create a cycle in the SSA
// graph.
// 2. DistributeTranspiler uses unique names to map parameters to
// gradients, so those vars must be skipped.
bool InSkipSet(const std::string& var) const;
bool CheckDeps(ir::Node* var, ir::Node* current_op) const;
bool CheckOpDeps(ir::Node* op1, ir::Node* op2) const;
void TopoSort(ir::Graph* g);
private:
std::vector<ir::Node*> ops_;
std::unordered_set<std::string> skip_set_;  // vars affected by memory optimization
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
std::unordered_map<ir::Node*, uint32_t> op_level_;
};
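// A minimal usage sketch (illustration only, not part of this commit):
// a caller drives GraphView with the methods declared above to find
// inplace candidates. The helper name and the call order are
// assumptions.
//
// inline void ScanForInplaceCandidates(ir::Graph* g) {
//   GraphView view;
//   view.TopoSort(g);  // level ops so CheckDeps/CheckOpDeps can compare
//   view.Build(g);     // collect ops and the skip set
//   for (ir::Node* op : view.AllOps()) {
//     for (ir::Node* var : op->inputs) {
//       if (view.InSkipSet(var->Name())) continue;  // see NOTE above
//       if (view.CheckDeps(var, op)) {
//         // `op` may safely reuse `var` in place; a real pass would
//         // now stage the rename and validate the resulting graph.
//       }
//     }
//   }
// }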
// swap pairs in sequence
typedef std::vector<std::pair<ir::Node*, ir::Node*>> NodeSwapQueue;
class InplacePass : public ir::Pass {
public:
InplacePass();
protected:
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;
private:
const NodeSwapQueue TryInplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx,
ir::Graph* graph) const;
void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const;
void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const;
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
const size_t& idx) const;
void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
mutable std::unordered_set<std::string> whitelist_;
mutable GraphView view_;
};
} // namespace details
} // namespace framework
} // namespace paddle
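The TryInplaceModifyVar / CommitModify / WithdrawModify trio above implies a stage-validate-commit flow. Below is a condensed, self-contained sketch of that pattern; the types are simplified stand-ins for NodeSwapQueue and the graph state, not Paddle's real ones:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for NodeSwapQueue: ordered (var, cache_var) swaps applied
// in sequence.
using SwapQueue = std::vector<std::pair<std::string, std::string>>;

// Stage a swap, validate it, then commit or withdraw it, mirroring
// TryInplaceModifyVar / CommitModify / WithdrawModify above.
bool TryInplace(const std::string& var, const std::string& cache_var,
                bool is_valid, SwapQueue* committed) {
  SwapQueue staged;
  staged.emplace_back(var, cache_var);  // TryInplaceModifyVar analogue
  if (!is_valid) return false;          // WithdrawModify: staged swaps die
  committed->insert(committed->end(), staged.begin(), staged.end());
  return true;                          // CommitModify analogue
}

int main() {
  SwapQueue graph_swaps;
  TryInplace("x", "out", /*is_valid=*/true, &graph_swaps);    // committed
  TryInplace("y", "out2", /*is_valid=*/false, &graph_swaps);  // withdrawn
  std::cout << graph_swaps.size() << " committed swap(s)\n";  // prints 1
}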
......@@ -56,7 +56,7 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
q.push(op);
do {
while (!q.empty()) {
op = q.front();
q.pop();
for (auto &pending_op : pending_ops_.at(op)) {
......@@ -65,9 +65,10 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
if (!callback(pending_op)) {
return false;
}
q.push(pending_op);
}
}
}
} while (!q.empty());
return true;
}
......
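The hunk above appears to be the "fix op graph view" part of the commit message: the BFS in VisitAllPendingOps never re-enqueued the ops it discovered, so only the direct pending ops of the start node were visited. A self-contained sketch of the difference, using a toy graph rather than Paddle types:

#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

// Count the nodes reachable from `start`. Dropping the q.push(v) line
// below reproduces the old bug: only direct successors get visited.
int CountReachable(const std::vector<std::vector<int>>& adj, int start) {
  std::unordered_set<int> visited;
  std::queue<int> q;
  q.push(start);
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    for (int v : adj[u]) {
      if (!visited.insert(v).second) continue;  // already seen
      q.push(v);  // the line the fix adds back
    }
  }
  return static_cast<int>(visited.size());
}

int main() {
  // 0 -> 1 -> 2: with the push, 2 is reachable from 0; without it, only 1.
  std::vector<std::vector<int>> adj = {{1}, {2}, {}};
  std::cout << CountReachable(adj, 0) << "\n";  // prints 2
}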
......@@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor {
const OpGraphView graph_;
};
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
} while (!q.empty());
return nullptr;
}
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
const ShrinkDepsOpFunctor &shrink_func,
bool *ok) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable, so the last operator is the op
// that generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}
// Either there are no pending ops, or the generating op is nullptr
if (candidates.empty()) {
*ok = false;
return {};
}
}
// stage two. Try to cast them to computation ops.
// return (*ok=false) when failed.
//
// The reason we cannot let an arbitrary op handle be the last lived
// op is that some op handles may operate on many DeviceContexts,
// while our garbage collector can only wait on one DeviceContext for
// now. So currently, we wait on the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
*ok = false;
return {};
}
computation_op.emplace(compute_op);
}
}
// stage three. Try to shrink the computation ops if they depend on
// each other, keeping the smallest set that still covers all of them.
*ok = true;
return shrink_func(computation_op);
}
/**
* Shrink op dependencies according to no need buffer vars.
*
......@@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
}
return nullptr;
}
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable, so the last operator is the op
// that generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}
// Either there are no pending ops, or the generating op is nullptr
if (candidates.empty()) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
}
// stage two. Try to cast them to computation ops.
// return (*status=kFailure) when failed.
//
// The reason we cannot let an arbitrary op handle be the last lived
// op is that some op handles may operate on many DeviceContexts,
// while our garbage collector can only wait on one DeviceContext for
// now. So currently, we wait on the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
computation_op.emplace(compute_op);
}
}
// stage three. Try to shrink the computation ops if any of them does
// not need the buffer of var_name.
// If none of the computation ops needs the buffer of var_name, return
// an empty op set and mark the status as kShouldPrecede, which means
// that the last live ops of var_name should be searched for in the
// previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) {
*status = LastLiveOpSearchStatus::kShouldPrecede;
return {};
}
PADDLE_ENFORCE(!computation_op.empty(),
"Computation ops should not be empty");
// stage four. Try to shrink the computation ops if they depend on
// each other, keeping the smallest set that still covers all of them.
*status = LastLiveOpSearchStatus::kSuccess;
return shrink_func(computation_op);
}
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
......@@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
// Can this variable be reused or deleted? If not, we do not
// compute reference counts or dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
if (var_desc == nullptr || var_desc->Persistable()) {
continue;
}
......@@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second;
PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);
for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) {
bool ok;
auto result =
ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok);
VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
*iter, i, var_name, shrink_func, &status);
// Rarely, some vars may have no pending or preceding computation ops;
// just break.
if (!ok) break;
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
if (status == LastLiveOpSearchStatus::kFailure) {
break;
}
size_t original_op_deps = result.size();
// If no op needs the buffer of var_name, calculate the reference count
// of the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
if (status == LastLiveOpSearchStatus::kShouldPrecede) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
}
size_t final_op_deps = result.size();
if (final_op_deps < original_op_deps) {
VLOG(5) << "Shrink op deps from " << original_op_deps << " to "
<< final_op_deps;
}
PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess);
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break;
......
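Putting the pieces of this file together: ExtractComputationOpFromLastLivedVar now reports a three-way LastLiveOpSearchStatus, and ApplyImpl walks the versions of each variable from newest to oldest, breaking on kFailure, stepping back one version on kShouldPrecede, and recording reference counts on kSuccess. A condensed, self-contained sketch of that control flow (the per-version data is made up):

#include <iostream>

enum class Status { kSuccess, kFailure, kShouldPrecede };

// Stand-in for ExtractComputationOpFromLastLivedVar: each version of a
// variable either yields its last-live ops, defers to the previous
// version (all ops are buffer-free), or has no computation ops at all.
Status SearchVersion(int version, int* op_count) {
  if (version == 2) return Status::kShouldPrecede;
  if (version == 1) { *op_count = 3; return Status::kSuccess; }
  return Status::kFailure;
}

int main() {
  // Mirror the rbegin/rend loop in ApplyImpl: newest version first.
  for (int version = 2; version >= 0; --version) {
    int ops = 0;
    Status s = SearchVersion(version, &ops);
    if (s == Status::kFailure) break;          // rare: nothing to wait on
    if (s == Status::kShouldPrecede) continue; // try the previous version
    std::cout << "version " << version << " kept alive by " << ops
              << " op(s)\n";                   // prints: version 1 ... 3 op(s)
    break;                                     // kSuccess: record and stop
  }
}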
......@@ -18,7 +18,6 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -27,9 +26,15 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
USE_PASS(inplace_pass);
namespace paddle {
namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
return ir::PassRegistry::Instance().Get("inplace_pass");
}
class NOP : public OperatorBase {
public:
NOP(const std::string& type, const VariableNameMap& inputs,
......@@ -202,7 +207,7 @@ ir::Node* GetNodeFromGraph(ir::Graph* g, std::string name) {
std::unique_ptr<ir::Graph> test_SingleOpInplaceInToOut(
std::unique_ptr<ir::Graph> g) {
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
auto pass = CreateInplacePass();
ir::Node* op_node = GetNodeFromGraph(g.get(), "single_op");
EXPECT_NE(op_node, nullptr);
pass->Apply(g.get());
......@@ -268,7 +273,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
ASSERT_TRUE(op_node != nullptr);
......@@ -304,7 +309,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
ASSERT_TRUE(op_node != nullptr);
......
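The test changes above drop the direct `new details::InplacePass()` construction in favor of a lookup by name via USE_PASS(inplace_pass) and the CreateInplacePass() helper. A minimal sketch of the name-to-factory pattern this relies on; the registry below is an assumed simplification, not Paddle's real ir::PassRegistry:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Assumed, simplified registry: real code registers passes via macros.
struct Pass {
  virtual ~Pass() = default;
  virtual void Apply() = 0;
};

struct InplacePass : Pass {
  void Apply() override { std::cout << "inplace pass applied\n"; }
};

std::map<std::string, std::function<std::unique_ptr<Pass>()>>& Registry() {
  static std::map<std::string, std::function<std::unique_ptr<Pass>()>> r;
  return r;
}

int main() {
  // Registration normally happens in the pass's own translation unit.
  Registry()["inplace_pass"] = [] {
    return std::unique_ptr<Pass>(new InplacePass);
  };
  // CreateInplacePass analogue: tests only need the name, not the type.
  auto pass = Registry().at("inplace_pass")();
  pass->Apply();
}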
......@@ -108,11 +108,18 @@ class Node {
Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
void RenameVar(const std::string& new_name) {
PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_,
"Must be type of variable");
name_ = new_name;
var_desc_->SetName(new_name);
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected:
const std::string name_;
std::string name_;
std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_;
Type type_;
......
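The node.h hunk above adds RenameVar() and drops the const qualifier on name_ so the node name and its VarDesc stay in sync after a rename. A tiny stand-in illustrating why the qualifier had to go (simplified types, not Paddle's):

#include <cassert>
#include <memory>
#include <string>

// Simplified stand-ins for VarDesc and ir::Node.
struct FakeVarDesc {
  std::string name;
  void SetName(const std::string& n) { name = n; }
};

struct FakeNode {
  std::string name_;  // was `const std::string name_`; const would make
                      // the assignment in RenameVar ill-formed
  std::unique_ptr<FakeVarDesc> var_desc_;

  void RenameVar(const std::string& new_name) {
    assert(var_desc_ != nullptr);  // mirrors the PADDLE_ENFORCE guard
    name_ = new_name;
    var_desc_->SetName(new_name);  // keep desc and node name consistent
  }
};

int main() {
  FakeNode n{"x", std::unique_ptr<FakeVarDesc>(new FakeVarDesc{"x"})};
  n.RenameVar("x_renamed");
  assert(n.name_ == "x_renamed" && n.var_desc_->name == "x_renamed");
}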
......@@ -220,16 +220,6 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
}
};
class SoftmaxInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"},
};
}
};
} // namespace operators
} // namespace paddle
......
......@@ -74,3 +74,7 @@ class TestIrInplace(TestParallelExecutorBase):
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta)
if __name__ == '__main__':
unittest.main()