未验证 提交 7cd24b13 编写于 作者: D dzhwinter 提交者: GitHub

add ir memory optimize. (#14530)

* follow comments. test=develop

* Fix typo

* fix compile error. test=develop

* merge develop branch. test=develop

* Remove set_equal

* Polish code

* Delete unused functions

test=develop

* polish code. test=develop

* follow comment

* polish code.

* fix windows compile error. test=develop

* fix op handle.

* rerun ci. test=develop

* rerun ci. test=develop

* rerun macci. test=develop

* polish code. test=develop

* rewrite sort code. test=develop

* remove unused code. test=develop

* fix tests. test=develop

* fix conflict. test=develop

* follow comment. test=develop

* merge develop branch. test=develop

* fix tests. test=develop

* remove ToTypeIndex. test=develop

* rerun ci. test=develop
上级 fd1d2c89
...@@ -50,8 +50,10 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -50,8 +50,10 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
...@@ -63,7 +65,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he ...@@ -63,7 +65,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
cc_test(memory_reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types.cc DEPS framework_proto graph)
cc_test(analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
...@@ -84,4 +91,5 @@ cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fuse ...@@ -84,4 +91,5 @@ cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fuse
cc_library(build_strategy SRCS build_strategy.cc DEPS cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass) fuse_elewise_add_act_pass multi_batch_merge_pass
memory_optimize_pass)
此差异已折叠。
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// sort op in bfs order
std::vector<ir::Node*> BFSSortGraphOps(const ir::Graph& graph);
class ControlFlowGraph;
class AnalysisVarPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;
// update program descs
void RenameVarInGraphDesc(const std::string& var,
const std::string& cache_var, size_t idx) const;
// update ir nodes
void RenameVarInGraphNode(const std::string& var,
const std::string& cache_var, size_t idx,
ir::Graph* graph) const;
void SubGraphOptimize(OpDesc* op_desc) const;
// valid a tensor can be reuse or not
bool NodeCanReused(ir::Node* node) const;
// scan subblock and collect the output/input variables.
std::unordered_set<std::string> GetSubBlockVars(
const std::unordered_set<ir::Node*>&) const;
// check op has subblock or not
bool OpHasSubBlock(OpDesc* desc) const;
private:
// Reuse Node Pool, Owned.
mutable OrderedNodePairPool pool_;
// controlflow Graph
mutable std::unique_ptr<ControlFlowGraph> cfg_;
// skip set
mutable std::unordered_set<std::string> skip_set_;
// var nodes
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};
class ControlFlowGraph {
public:
ControlFlowGraph() = default;
// For IR Graph in parallelexecutor
explicit ControlFlowGraph(const ir::Graph& graph);
void LiveVariableAnalysis();
void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx);
const std::set<std::string> LiveIn(ir::Node* op) const;
const std::set<std::string> LiveOut(ir::Node* op) const;
const std::set<std::string> Use(ir::Node* op) const;
const std::vector<ir::Node*> Ops() const;
std::vector<ir::Node*>& Ops();
// for ssa-graph nodes
ir::Node* GetNodeFromVarName(const std::string& name, ir::Node* op) const;
private:
void BuildCFGGraph();
void ConnectNodes();
using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
// successors ops use the output variables.
NodeListMap successors_;
// predecessors ops generated input variables.
NodeListMap predecessors_;
// variables lived before run current op.
VarSetMap live_in_;
// variables lived after run current op.
VarSetMap live_out_;
VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs
std::vector<ir::Node*> ops_; // op sequence by topology sort
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/analysis_var_pass.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace paddle {
namespace framework {
namespace details {
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(proto::VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"d"});
op->SetOutput("Out", {"e"});
}
return prog;
}
template <typename Container>
inline static std::string DebugString(const Container& c) {
std::stringstream ss;
for (auto& item : c) {
ss << item << " ";
}
return ss.str();
}
TEST(CFGGraph, IRGraph) {
// prepare ir graph
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
ControlFlowGraph cfg(graph);
cfg.LiveVariableAnalysis();
// test assign op
ASSERT_TRUE((std::set<std::string>{"a"} == cfg.LiveIn(cfg.Ops()[0])));
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveOut(cfg.Ops()[0])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveIn(cfg.Ops()[1])));
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveOut(cfg.Ops()[1])));
// test sum op
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveIn(cfg.Ops()[2])));
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveOut(cfg.Ops()[2])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveIn(cfg.Ops()[3])));
ASSERT_TRUE((std::set<std::string>{} == cfg.LiveOut(cfg.Ops()[3])));
}
// 1. normal test
TEST(SortOpLikeDescOrder, NormalTest) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = SortOpLikeDescOrder(graph);
auto op_descs = prog.Block(0).AllOps();
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 2. remove some op_desc
TEST(SortOpLikeDescOrder, RemoveOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = graph.Nodes();
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->outputs.back()->Name() == "e") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
graph.RemoveNode(found_node);
graph.RemoveNode(e);
// other node keeps the same order
auto remain_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < remain_nodes.size(); ++i) {
auto node = remain_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 3. add some op_desc
TEST(SortOpLikeDescOrder, AddOpDesc) {
auto prog = FillProgramDesc();
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto op_descs = prog.Block(0).AllOps();
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
op_descs.insert(op_descs.begin() + 4, op);
auto nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 4. add and delete some op_desc
TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// remove sum node
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
auto nodes = graph.Nodes();
for (auto node : nodes) {
if (node->Name() == "sum") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* c = find_node_in_graph("c");
ir::Node* e = find_node_in_graph("e");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
std::remove(c->outputs.begin(), c->outputs.end(), found_node);
ir::Node* pending_op = found_node->outputs[0]->outputs[0];
graph.RemoveNode(e);
graph.RemoveNode(pending_op);
graph.RemoveNode(found_node);
}
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.insert(op_descs.begin() + 2, op);
// check the order
auto mynodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < mynodes.size(); ++i) {
auto node = mynodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 5. add and replace some op_desc inplace.
TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
auto op_descs = prog.Block(0).AllOps();
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.emplace_back(op);
// replace op_desc inplace
auto nodes = graph.Nodes();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->Op() && node->Name() == "assign") {
if (node->outputs.size() == 1 && node->outputs[0]->Name() == "e") {
found_node = node;
break;
}
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* e = find_node_in_graph("e");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
std::remove(e->inputs.begin(), e->inputs.end(), found_node);
graph.RemoveNode(found_node);
}
op_descs.erase(op_descs.begin() + 3);
auto replace_op = prog.MutableBlock(0)->AppendOp();
replace_op->SetType("sum");
replace_op->SetInput("X", {"d", "d1"});
replace_op->SetOutput("Out", {"e"});
{
ir::Node* sum2 = graph.CreateOpNode(replace_op);
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
ir::Node* d1 = find_node_in_graph("d1");
sum2->inputs.emplace_back(d);
sum2->inputs.emplace_back(d1);
sum2->outputs.emplace_back(e);
e->inputs.emplace_back(sum2);
d->outputs.emplace_back(sum2);
d1->outputs.emplace_back(sum2);
}
op_descs.emplace_back(replace_op);
// compare op order
auto graph_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < graph_nodes.size(); ++i) {
auto node = graph_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -14,11 +14,16 @@ limitations under the License. */ ...@@ -14,11 +14,16 @@ limitations under the License. */
#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/build_strategy.h"
#include <glog/logging.h>
#include <memory>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h" #include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle { namespace paddle {
...@@ -69,6 +74,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -69,6 +74,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
} }
VLOG(1) << "CollectiveContext:" << context->String(); VLOG(1) << "CollectiveContext:" << context->String();
// NOTE(dzh): memory optimize should be a runtime pass.
// However, after multi_devices_pass, VarHandle, OpHandle is
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy.memory_optimize_) {
auto analysis_var_pass = AppendPass("analysis_var_pass");
}
// Convert graph to run on multi-devices. // Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass"); auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
...@@ -79,8 +92,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -79,8 +92,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add a graph print pass to record a graph with device info. // Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass"); auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
multi_devices_print_pass->SetNotOwned<const std::string>( const std::string graph_path =
"debug_graphviz_path", &strategy_.debug_graphviz_path_); string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(kGraphvizPath,
new std::string(graph_path));
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>( multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter); "graph_printer", new details::GraphvizSSAGraphPrinter);
} }
...@@ -127,7 +143,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -127,7 +143,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (pass->Type() == "multi_devices_pass") { if (pass->Type() == "multi_devices_pass") {
pass->Erase("places"); pass->Erase("places");
...@@ -145,6 +160,17 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -145,6 +160,17 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("nccl_ctxs"); pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx); pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif #endif
} else if (pass->Type() == "analysis_var_pass") {
const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
all_op_descs); // take ownership
graph->Set<GraphNodePool>(kGraphNodePool,
new GraphNodePool); // take ownership
pass->Erase(kAllOpDescs);
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, all_op_descs);
} else if (pass->Type() == "sequential_execution_pass") { } else if (pass->Type() == "sequential_execution_pass") {
LOG(INFO) << "set enable_sequential_execution:" LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_; << enable_sequential_execution_;
...@@ -166,6 +192,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -166,6 +192,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
} }
return graph; return graph;
} }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -176,6 +203,7 @@ USE_PASS(multi_batch_merge_pass); ...@@ -176,6 +203,7 @@ USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass); USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
USE_PASS(analysis_var_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass); USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(modify_op_lock_and_record_event_pass);
...@@ -60,8 +60,15 @@ struct BuildStrategy { ...@@ -60,8 +60,15 @@ struct BuildStrategy {
kCustomized = 2, kCustomized = 2,
}; };
enum class OptimizeStrategy {
// To be Implemented,bruteforce, recursive compute unused var names.
kBruteForce = 0,
kControlFlowGraph = 1, // use cfg_graph algorithm, faster speed.
};
ReduceStrategy reduce_{ReduceStrategy::kAllReduce}; ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
OptimizeStrategy strategy_{OptimizeStrategy::kControlFlowGraph};
std::string debug_graphviz_path_{""}; std::string debug_graphviz_path_{""};
...@@ -69,6 +76,10 @@ struct BuildStrategy { ...@@ -69,6 +76,10 @@ struct BuildStrategy {
bool enable_data_balance_{false}; bool enable_data_balance_{false};
bool memory_optimize_{false};
bool memory_early_delete_{false};
bool enable_sequential_execution_{false}; bool enable_sequential_execution_{false};
bool fuse_broadcast_op_{false}; bool fuse_broadcast_op_{false};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
class EarlyDeleteOpHandle : public OpHandleBase {
public:
EarlyDeleteOpHandle(ir::Node* node, const Scope* scope,
const platform::Place& place,
const std::vector<std::string>& names,
GarbageCollector* gc)
: OpHandleBase(node),
scope_(scope),
place_(place),
names_(names),
gc_(gc) {
#ifdef PADDLE_WITH_CUDA
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(place);
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
#endif
}
~EarlyDeleteOpHandle() {
#ifdef PADDLE_WITH_CUDA
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
#endif
}
std::string Name() const override { return "early_delete"; }
protected:
void RunImpl() override {
std::vector<std::shared_ptr<memory::Allocation>> tensors;
auto* local_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope*>();
for (auto& var_name : names_) {
auto* var = local_scope->FindVar(var_name);
PADDLE_ENFORCE(var != nullptr,
string::Sprintf("Local Scope not has var %s", var_name));
if (var->IsType<LoDTensor>()) {
tensors.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
tensors.emplace_back(var->GetMutable<SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
LoDTensorArray* tensor_array = var->GetMutable<LoDTensorArray>();
for (auto& tensor : *tensor_array) {
tensors.emplace_back(tensor.MoveMemoryHolder());
}
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
if (platform::is_cpu_place(place_)) {
ClearCPUTensors(tensors);
} else {
ClearGPUTensors(tensors);
}
}
void ClearCPUTensors(
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
auto* gc = dynamic_cast<CPUGarbageCollector*>(gc_);
if (gc != nullptr) {
gc->Add(tensors);
}
}
void ClearGPUTensors(
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
#ifdef PADDLE_WITH_CUDA
auto* gc = dynamic_cast<StreamGarbageCollector*>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector*>(gc_) != nullptr;
#endif
}
const Scope* scope_;
const platform::Place place_;
std::vector<std::string> names_;
GarbageCollector* gc_;
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext* dev_ctx_;
cudaEvent_t event_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) {
std::queue<VarHandleBase*> queue;
queue.push(var_in);
do {
auto* var = queue.front();
queue.pop();
for (auto* op : var->PendingOps()) {
auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
return compute_op;
}
for (auto* out_var : op->Outputs()) {
queue.push(out_var);
}
}
} while (!queue.empty());
return nullptr;
}
std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto& graph_pool = Get<GraphNodePool>(kGraphNodePool);
auto& gcs = Get<GarbageCollectorMap>(kGarbageCollector);
std::unordered_map<std::string, std::unordered_set<OpDesc*>> unlived_vars;
unlived_vars.reserve(graph_pool.size());
for (auto& pair : graph_pool) {
unlived_vars.insert(std::make_pair(pair.first, pair.second));
}
auto compare_and_insert_early_delete_op = [&](
OpHandleBase* op, const std::vector<VarHandleBase*>& vars) {
if (unlived_vars.empty()) return;
// unlived vars can be deleted after the last used op has finished.
auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
const auto& places = Get<std::vector<platform::Place>>(kAllPlaces);
for (auto& var : vars) {
auto* var_handle = dynamic_cast<VarHandle*>(var);
auto var_name = var->Node()->Name();
auto& var_place = var_handle->place_;
if (unlived_vars.count(var_name) == 0) continue;
if (!unlived_vars[var_name].empty()) {
if (compute_op != nullptr &&
unlived_vars[var_name].count(compute_op->Node()->Op()) != 0) {
unlived_vars[var_name].erase(compute_op->Node()->Op());
}
continue;
}
if (var_handle == nullptr || !var_handle->Node()->IsVar() ||
var_handle->Node()->IsCtrlVar())
continue;
// shameless copyed from reference count pass.
if (compute_op == nullptr) {
// use next computation op scope
compute_op = FindNextComputationOpHandle(var_handle);
}
auto* early_delete_node =
graph->CreateEmptyNode("early_delete", ir::Node::Type::kOperation);
GarbageCollector* gc = gcs.at(places[compute_op->GetScopeIdx()]).get();
auto* early_delete_handle = new EarlyDeleteOpHandle(
early_delete_node, compute_op->GetScope(), var_place, {var_name}, gc);
if (compute_op->Outputs().empty()) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
early_delete_handle->AddInput(compute_op->Outputs().front());
VLOG(5) << "Add early delete op " << var_name << " to Operator"
<< compute_op->Name();
}
};
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto& op : all_ops) {
compare_and_insert_early_delete_op(op, op->Inputs());
compare_and_insert_early_delete_op(op, op->Outputs());
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(memory_early_delete_pass,
paddle::framework::details::MemoryEarlyDeletePass)
.RequireGraphAttr(paddle::framework::details::kGraphNodePool)
.RequireGraphAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/early_delete_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class MemoryEarlyDeletePass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include <iostream>
#include <sstream>
#include <string>
namespace paddle {
namespace framework {
namespace details {
size_t NodeSizeInBytes(ir::Node* n) {
auto* desc = FindVarDescInBlock(n);
auto shape = desc->GetShape();
size_t type_size = SizeOfType(desc->GetDataType());
int size = 1;
for (auto& s : shape) {
size *= s;
}
return type_size * std::abs(size);
}
std::string DebugStringImpl(VarDesc* var) {
std::stringstream ss;
ss << var->Name();
ss << "[";
try {
auto shape = var->GetShape();
for (size_t i = 0; i < shape.size(); ++i) {
if (i != shape.size() - 1) {
ss << shape[i] << ",";
} else {
ss << shape[i];
}
}
ss << "]";
} catch (...) {
ss << "Var has no VarDesc !!! Name:" << var->Name();
}
return ss.str();
}
std::string DebugString(ir::Node* var) {
return DebugStringImpl(FindVarDescInBlock(var));
}
// return DebugString(var->Var()); }
// NOTE(dzh): based ir node, if a large node has been reused
// by a small size node, then next time it appear in pool, it will
// have the small size. Find the original node shap from blockdesc.
VarDesc* FindVarDescInBlock(ir::Node* n) {
PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
BlockDesc* block = n->inputs[0]->Op()->Block();
PADDLE_ENFORCE(block->HasVar(n->Name()),
string::Sprintf("Block do not has var %s", n->Name()));
return block->FindVar(n->Name());
}
struct NodeComparator {
bool operator()(ir::Node* lhs, ir::Node* rhs) const {
auto* lhs_desc = FindVarDescInBlock(lhs);
auto* rhs_desc = FindVarDescInBlock(rhs);
auto lhs_shape = lhs_desc->GetShape();
auto rhs_shape = rhs_desc->GetShape();
if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
(lhs_shape[0] != -1 && rhs_shape[0] != -1)) {
return NodeSizeInBytes(lhs) <= NodeSizeInBytes(rhs);
} else {
return false;
}
}
};
void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
PADDLE_ENFORCE(op->IsOp());
if (mark_table_.count(var->Name()) != 0) {
mark_table_[var->Name()]->second.insert(op);
return;
}
auto* var_desc = FindVarDescInBlock(var);
auto var_shape = var_desc->GetShape();
int batch_size = static_cast<int>(var_shape[0]);
NodeComparator compare_node;
Iter it = nodes_.begin();
while (it != nodes_.end()) {
auto* cache_desc = FindVarDescInBlock(it->first);
int cache_batch_size = cache_desc->GetShape()[0];
if ((cache_batch_size == -1 && batch_size == -1) ||
(cache_batch_size != -1 && batch_size != -1)) {
if (compare_node(it->first, var)) {
++it;
} else {
break;
}
} else if (cache_batch_size == -1 && batch_size != -1) {
++it;
} else if (cache_batch_size != -1 && batch_size == -1) {
break;
}
}
it =
nodes_.insert(it, std::make_pair(var, std::unordered_set<ir::Node*>{op}));
mark_table_[var->Name()] = it;
}
int OrderedNodePairPool::GetIndex(ir::Node* var) {
return std::distance(nodes_.begin(), mark_table_[var->Name()]);
}
ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
ir::Node* found_node = nullptr;
NodeComparator compare_node;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
if (compare_node(var, it->first)) {
found_node = it->first;
break;
}
}
return found_node;
}
void OrderedNodePairPool::Erase(ir::Node* var) {
PADDLE_ENFORCE(mark_table_.count(var->Name()));
nodes_.erase(mark_table_[var->Name()]);
mark_table_.erase(var->Name());
}
std::string OrderedNodePairPool::ToString() const {
std::stringstream ss;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
ss << DebugString(it->first) << " ";
}
return ss.str();
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kFetchedVars[] = "fetched_vars";
constexpr char kGraphNodePool[] = "graph_node_pool";
// NOTE(dzh): Variable and the operators use the var.
// for early delete pass.
// Because analysis var pass build base on ir::Node, which maybe released
// or modified between passes, so we use OpDesc* to mark ops.
using GraphNodePool = std::vector<
std::pair<std::string /*var node*/, std::unordered_set<OpDesc*> /* ops */>>;
// NOTE(dzh): by default, it sort node in ascend order(by node bytes size).
// in fluid, -1 means the batch_size is determined in runtime.
// the node batch_size equal -1 always ranking in the front than the node not.
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete
class OrderedNodePairPool {
public:
using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>;
using Iter = typename std::list<NodePair>::iterator;
using ConstIter = typename std::list<NodePair>::const_iterator;
void Insert(ir::Node* var, ir::Node* op);
void Erase(ir::Node* var);
bool Has(ir::Node* var) { return mark_table_.count(var->Name()); }
ir::Node* NodeMatch(ir::Node* var) const;
// map store non-const iterator, can not promise const
int GetIndex(ir::Node* var);
// pool all node to string
std::string ToString() const;
Iter begin() { return nodes_.begin(); }
Iter end() { return nodes_.end(); }
ConstIter begin() const { return nodes_.begin(); }
ConstIter end() const { return nodes_.end(); }
size_t size() const { return nodes_.size(); }
private:
// for searching.
std::unordered_map<std::string, Iter> mark_table_;
// node swap pairs. var -> ops dep var
std::list<NodePair> nodes_;
};
// node memory size in bytes
size_t NodeSizeInBytes(ir::Node* n);
std::string DebugString(ir::Node* var);
// std::string DebugString(VarDesc* var);
VarDesc* FindVarDescInBlock(ir::Node* n);
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
namespace paddle {
namespace framework {
namespace details {
TEST(OrderedNodePairPool, Normal) {
OrderedNodePairPool pool;
std::vector<std::unique_ptr<ir::Node>> nodes;
// clang-format off
std::vector<std::vector<int64_t>> shapes = {{-1, 10},
{-1, 20},
{1, 2},
{5, 2},
{10, 20},
{-1, 2, 5},
{-1, 1, 5},
{-1, 1}};
// clang-format on
const int COUNT = shapes.size();
ProgramDesc prog;
BlockDesc* block_desc = prog.MutableBlock(0);
auto* op_desc = block_desc->AppendOp();
op_desc->SetType("dummy");
std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);
for (int i = 0; i < COUNT; ++i) {
auto desc = block_desc->Var(std::to_string(i));
desc->SetShape(shapes[i]);
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
for (auto& node : nodes) {
pool.Insert(node.get(), op.get());
}
// assert its order and interface.
std::cout << pool.ToString() << std::endl;
pool.Erase(nodes.front().get());
std::cout << pool.ToString() << std::endl;
ASSERT_EQ(pool.size(), static_cast<size_t>(COUNT - 1));
ASSERT_EQ(pool.GetIndex(nodes.back().get()), 0);
{
auto v1 = block_desc->Var("11");
v1->SetShape({-1, 256, 56, 56});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v1);
node1->inputs.emplace_back(op.get());
auto* cache = pool.NodeMatch(node1.get());
ASSERT_EQ(cache, nullptr);
}
{
auto v2 = block_desc->Var("12");
v2->SetShape({-1, 2, 5});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v2);
node1->inputs.emplace_back(op.get());
auto* cache = pool.NodeMatch(node1.get());
ASSERT_EQ(pool.GetIndex(cache), 2); // match 6:[-1,2,5]
}
{
auto v3 = block_desc->Var("13");
v3->SetShape({2, 5});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v3);
node1->inputs.emplace_back(op.get());
auto* cache = pool.NodeMatch(node1.get());
ASSERT_EQ(pool.GetIndex(cache), 5); // match 4:[5,2]
}
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -85,4 +85,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -85,4 +85,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_devices_print_pass, REGISTER_PASS(multi_devices_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter); paddle::framework::details::SSAGraghBuilderWithPrinter)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <glog/logging.h>
#include <fstream> #include <fstream>
#include <iosfwd> #include <iosfwd>
#include <ostream> #include <ostream>
...@@ -24,6 +25,8 @@ namespace paddle { ...@@ -24,6 +25,8 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
class SSAGraphPrinter { class SSAGraphPrinter {
public: public:
virtual ~SSAGraphPrinter() {} virtual ~SSAGraphPrinter() {}
...@@ -40,7 +43,7 @@ class SSAGraghBuilderWithPrinter : public ir::Pass { ...@@ -40,7 +43,7 @@ class SSAGraghBuilderWithPrinter : public ir::Pass {
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<const std::string>("debug_graphviz_path"))); new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout); Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph; return graph;
......
...@@ -25,7 +25,7 @@ namespace paddle { ...@@ -25,7 +25,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; constexpr char kLocalExecScopeName[] = "@LOCAL_SCOPE@";
// Wraps ir::Node and provide helper utilities. // Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node. // It's responsible for populating necessary fields of ir::Node.
......
...@@ -162,7 +162,10 @@ void Graph::ResolveHazard( ...@@ -162,7 +162,10 @@ void Graph::ResolveHazard(
(*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0]; (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
const auto &read_ops = (*it_old)->outputs; const auto &read_ops = (*it_old)->outputs;
PADDLE_ENFORCE(write_op, "The write_op should not be empty."); PADDLE_ENFORCE(
write_op,
string::Sprintf("The write_op of var %s should not be empty.",
(*it_new)->Name()));
// Add write after write dependence // Add write after write dependence
ir::Node *upstream_op = ir::Node *upstream_op =
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <iosfwd> #include <iosfwd>
#include <ostream> #include <ostream>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
DEFINE_string(print_sub_graph_dir, "", DEFINE_string(print_sub_graph_dir, "",
...@@ -121,7 +122,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( ...@@ -121,7 +122,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
} }
size_t GraphNum(const Graph &graph) { size_t GraphNum(const Graph &graph) {
std::unordered_set<ir::Node *> nodes = graph.Nodes(); std::unordered_set<ir::Node *> nodes(graph.Nodes());
std::unordered_set<ir::Node *> visited_nodes; std::unordered_set<ir::Node *> visited_nodes;
visited_nodes.reserve(nodes.size()); visited_nodes.reserve(nodes.size());
std::deque<ir::Node *> q_nodes; std::deque<ir::Node *> q_nodes;
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Test if the graph contains circle. // Test if the graph contains circle.
bool HasCircle(const Graph &graph); bool HasCircle(const Graph &graph);
......
...@@ -30,6 +30,14 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name, ...@@ -30,6 +30,14 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
return std::unique_ptr<Node>(new Node(name, type)); return std::unique_ptr<Node>(new Node(name, type));
} }
std::unique_ptr<Node> CreateNodeForTest(VarDesc *var_desc) {
return std::unique_ptr<Node>(new Node(var_desc));
}
std::unique_ptr<Node> CreateNodeForTest(OpDesc *op_desc) {
return std::unique_ptr<Node>(new Node(op_desc));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -125,6 +124,8 @@ class Node { ...@@ -125,6 +124,8 @@ class Node {
friend class Graph; friend class Graph;
friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name, friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type); Node::Type type);
friend std::unique_ptr<Node> CreateNodeForTest(VarDesc* var_desc);
friend std::unique_ptr<Node> CreateNodeForTest(OpDesc* op_desc);
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
...@@ -152,7 +153,9 @@ class Node { ...@@ -152,7 +153,9 @@ class Node {
std::unique_ptr<Node> CreateNodeForTest(const std::string& name, std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type); Node::Type type);
std::unique_ptr<Node> CreateNodeForTest(VarDesc* var_desc);
std::unique_ptr<Node> CreateNodeForTest(OpDesc* op_desc);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include <algorithm>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -93,6 +94,7 @@ class ParallelExecutorPrivate { ...@@ -93,6 +94,7 @@ class ParallelExecutorPrivate {
} }
} }
BuildStrategy build_strategy_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
Scope *global_scope_; // not owned Scope *global_scope_; // not owned
...@@ -169,6 +171,14 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts( ...@@ -169,6 +171,14 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_); eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(std::move(graph)); graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied"; VLOG(10) << "EagerDeletionPass Applied";
if (build_strategy_.memory_early_delete_) {
auto early_delete_pass =
ir::PassRegistry::Instance().Get("memory_early_delete_pass");
early_delete_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = early_delete_pass->Apply(std::move(graph));
}
VLOG(10) << "MemoryEarlyDeletePass Applied.";
} }
return graph; return graph;
...@@ -189,6 +199,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -189,6 +199,7 @@ ParallelExecutor::ParallelExecutor(
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_; member_->use_cuda_ = exec_strategy.use_cuda_;
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ = member_->use_all_reduce_ =
build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
...@@ -245,7 +256,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -245,7 +256,6 @@ ParallelExecutor::ParallelExecutor(
build_strategy.Apply(main_program, member_->places_, loss_var_name, build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
graph = member_->PrepareGCAndRefCnts(std::move(graph), graph = member_->PrepareGCAndRefCnts(std::move(graph),
...@@ -280,10 +290,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -280,10 +290,12 @@ ParallelExecutor::ParallelExecutor(
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, member_->places_,
std::move(graph)));
} else { } else {
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, member_->places_,
std::move(graph)));
} }
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
...@@ -423,5 +435,6 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -423,5 +435,6 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(memory_early_delete_pass);
USE_PASS(reference_count_pass); USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass); USE_PASS(eager_deletion_pass);
...@@ -74,6 +74,22 @@ TEST(Tensor, MutableData) { ...@@ -74,6 +74,22 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}), p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
platform::CPUPlace()); platform::CPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
float* p3 = nullptr;
float* p4 = nullptr;
// set src_tensor a different type but smaller size.
// memory block is supposed to be unchanged.
auto* tmp = src_tensor.mutable_data<uint8_t>(framework::make_ddim({2, 2}),
platform::CPUPlace());
p3 = reinterpret_cast<float*>(tmp);
EXPECT_EQ(p1, p3);
// set src_tensor a different type but bigger size.
// memory block is supposed to be changed.
auto* tmp2 = src_tensor.mutable_data<double>(
framework::make_ddim({2, 2, 3}), platform::CPUPlace());
p4 = reinterpret_cast<float*>(tmp2);
EXPECT_NE(p1, p4);
} }
// Not sure if it's desired, but currently, Tensor type can be changed. // Not sure if it's desired, but currently, Tensor type can be changed.
{ {
......
...@@ -960,6 +960,14 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -960,6 +960,14 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op, to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC") it may make the execution faster. Default False)DOC")
.def_property(
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
[](BuildStrategy &self, bool b) { self.memory_optimize_ = b; })
.def_property(
"memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; },
[](BuildStrategy &self, bool b) { self.memory_early_delete_ = b; })
.def("_finalize_strategy_and_create_passes", .def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> { [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true); return self.CreatePassesFromStrategy(true);
......
...@@ -150,7 +150,7 @@ def __bootstrap__(): ...@@ -150,7 +150,7 @@ def __bootstrap__():
read_env_flags += [ read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'selected_gpus' 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus'
] ]
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
......
...@@ -39,6 +39,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -39,6 +39,7 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None, seed=None,
use_parallel_executor=True, use_parallel_executor=True,
use_reduce=False, use_reduce=False,
use_ir_memory_optimize=False,
fuse_elewise_add_act_ops=False, fuse_elewise_add_act_ops=False,
optimizer=fluid.optimizer.Adam, optimizer=fluid.optimizer.Adam,
use_fast_executor=False, use_fast_executor=False,
...@@ -82,6 +83,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -82,6 +83,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_sequential_execution = enable_sequential_execution build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda(): if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True build_strategy.remove_unnecessary_lock = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
def _feed_data_helper(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
return img, label
def simple_fc_net(use_feed):
x, y = _feed_data_helper(use_feed)
hidden_layer = 4
for _ in range(hidden_layer):
x = fluid.layers.fc(input=x, size=20, act='relu')
y_predict = fluid.layers.fc(input=x, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def fc_with_inplace_net(use_feed):
x, y = _feed_data_helper(use_feed)
fc = fluid.layers.fc(input=x, size=20, act='relu')
fc = fluid.layers.fc(input=fc, size=10, act='relu')
reshape = fluid.layers.reshape(x=fc, shape=[-1, 2, 5])
reshape = fluid.layers.reshape(x=reshape, shape=[-1, 5, 2])
y_predict = fluid.layers.fc(input=reshape, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=4)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder)
def _dummy_data(self):
np.random.seed(5)
img = np.random.random(size=[32, 784]).astype(np.float32)
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare_ir_and_python_memory_optimize(self, model, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._dummy_data()
first_loss0, last_loss0 = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
memory_opt=False,
use_ir_memory_optimize=False)
first_loss1, last_loss1 = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
memory_opt=False,
use_ir_memory_optimize=True)
for loss in zip(first_loss0, first_loss1):
self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
for loss in zip(last_loss0, last_loss1):
self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
def test_simple_fc_net(self):
self._compare_ir_and_python_memory_optimize(simple_fc_net, False)
self._compare_ir_and_python_memory_optimize(simple_fc_net, True)
def test_fc_with_reshape_net(self):
self._compare_ir_and_python_memory_optimize(fc_with_inplace_net, False)
self._compare_ir_and_python_memory_optimize(fc_with_inplace_net, True)
if __name__ == '__main__':
unittest.main()
...@@ -43,6 +43,7 @@ SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"), ...@@ -43,6 +43,7 @@ SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
("conditional_block", "conditional_block_grad")] ("conditional_block", "conditional_block_grad")]
PRINT_LOG = False PRINT_LOG = False
FLAGS_memory_optimize = ""
class OrderedSet(MutableSet): class OrderedSet(MutableSet):
...@@ -121,6 +122,7 @@ class ControlFlowGraph(object): ...@@ -121,6 +122,7 @@ class ControlFlowGraph(object):
self._defs = defaultdict(OrderedSet) self._defs = defaultdict(OrderedSet)
self._live_in = defaultdict(OrderedSet) self._live_in = defaultdict(OrderedSet)
self._live_out = defaultdict(OrderedSet) self._live_out = defaultdict(OrderedSet)
self._skip_opt = skip_opt self._skip_opt = skip_opt
self.pool = [] self.pool = []
...@@ -144,7 +146,6 @@ class ControlFlowGraph(object): ...@@ -144,7 +146,6 @@ class ControlFlowGraph(object):
for i in range(self.op_size): for i in range(self.op_size):
self._uses[i].update(self._ops[i].input_arg_names()) self._uses[i].update(self._ops[i].input_arg_names())
self._defs[i].update(self._ops[i].output_arg_names()) self._defs[i].update(self._ops[i].output_arg_names())
self._live_in[i] = self._uses[i]
def _update_graph(self, old_name, new_name, begin_idx=0): def _update_graph(self, old_name, new_name, begin_idx=0):
for i in range(begin_idx, self.op_size): for i in range(begin_idx, self.op_size):
...@@ -177,20 +178,52 @@ class ControlFlowGraph(object): ...@@ -177,20 +178,52 @@ class ControlFlowGraph(object):
worklist.append(d) worklist.append(d)
def _fill_pool(self, i, is_forward): def _fill_pool(self, i, is_forward):
def comparator(x, cache):
x_shape = x[1]
cache_shape = cache[1]
x_size = abs(reduce(lambda x, y: x * y, x_shape))
cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
if (x_shape[0] == -1 and cache_shape[0] == -1) or \
(x_shape[0] != -1 and cache_shape[0] != -1) :
return x_size <= cache_size
else:
return False
def find_var_in_block(x):
known_vars = set()
for op in self._ops:
known_vars.update(op.output_arg_names())
return x in known_vars
block_desc = self._ops[i].block() block_desc = self._ops[i].block()
in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i]) in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
# NOTE: must sort the in_diff set for cases that get different cache var. # NOTE: must sort the in_diff set for cases that get different cache var.
# FIXME(typhoonzero): maybe use a "sorted set" is better than this. # FIXME(typhoonzero): maybe use a "sorted set" is better than this.
can_optimize = [ can_optimize = [
x for x in in_diff x for x in sorted(in_diff)
if self._check_var_validity(block_desc, x, is_forward) if self._check_var_validity(block_desc, x, is_forward)
] ]
if can_optimize: if can_optimize:
for var_name in can_optimize: for var_name in can_optimize:
cache = (var_name, self._find_var(block_desc, var_name, cache = (var_name, self._find_var(block_desc, var_name,
is_forward).shape()) is_forward).shape())
if cache not in self.pool: if cache not in self.pool and find_var_in_block(var_name):
self.pool.append(cache) i = 0
while i < len(self.pool):
mycache = self.pool[i]
mysize = mycache[1][0]
cache_size = cache[1][0]
if (mysize == -1 and cache_size == -1) or \
(mysize != -1 and cache_size != -1):
if comparator(mycache, cache):
i += 1
else:
break
elif mysize == -1 and cache_size != -1:
i += 1
elif mysize != -1 and cache_size == -1:
break
self.pool.insert(i, cache)
def _get_diff(self, a, b): def _get_diff(self, a, b):
u = a & b u = a & b
...@@ -229,7 +262,7 @@ class ControlFlowGraph(object): ...@@ -229,7 +262,7 @@ class ControlFlowGraph(object):
def _update_skip_opt_set(self): def _update_skip_opt_set(self):
for i in range(self.op_size): for i in range(self.op_size):
op = self._ops[i] op = self._ops[i]
if op.type() == "fill_constant" and op.attr("force_cpu") == True: if op.has_attr("force_cpu") and op.attr("force_cpu") == True:
self._skip_opt.update(op.output_arg_names()) self._skip_opt.update(op.output_arg_names())
def release_memory(self, skip_opt_set=None): def release_memory(self, skip_opt_set=None):
...@@ -281,6 +314,7 @@ class ControlFlowGraph(object): ...@@ -281,6 +314,7 @@ class ControlFlowGraph(object):
# update skip set to meet users' demand # update skip set to meet users' demand
if skip_opt_set: if skip_opt_set:
self._skip_opt.update(skip_opt_set) self._skip_opt.update(skip_opt_set)
counter = 0
for i in range(self.op_size): for i in range(self.op_size):
op = self._ops[i] op = self._ops[i]
if op.type() in SUB_BLOCK_OPS: if op.type() in SUB_BLOCK_OPS:
...@@ -301,6 +335,9 @@ class ControlFlowGraph(object): ...@@ -301,6 +335,9 @@ class ControlFlowGraph(object):
# If x is both in uses and defs, it can not be optimized! # If x is both in uses and defs, it can not be optimized!
if x in self._uses[i]: if x in self._uses[i]:
continue continue
if x == FLAGS_memory_optimize:
print("start match var ", x, " of op ", op.type())
print(self.pool)
for index, cache_pair in enumerate(self.pool): for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0] cache_var = cache_pair[0]
cache_shape = cache_pair[1] cache_shape = cache_pair[1]
...@@ -323,15 +360,13 @@ class ControlFlowGraph(object): ...@@ -323,15 +360,13 @@ class ControlFlowGraph(object):
if not compare_shape(x_shape, cache_shape, level): if not compare_shape(x_shape, cache_shape, level):
continue continue
# TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype] # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
if x_dtype != cache_dtype:
continue
if PRINT_LOG: if PRINT_LOG:
print(("Hit Cache !!!! cache pool index " print(
"is %d, var name is %s, " ("!!! %d, %s => %s, cache idx %d, pool size %d"
"cached var name is %s, " % (counter, x + str(x_shape),
"var shape is %s ") % (index, x, cache_var, cache_var + str(cache_shape), index,
str(cache_shape))) len(self.pool))))
counter += 1
self.pool.pop(index) self.pool.pop(index)
# Rename the var to the cache var already with # Rename the var to the cache var already with
# memory allocated in order to reuse the memory. # memory allocated in order to reuse the memory.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册