Commit 2739096e authored by D dzhwinter

compatible with Python-side mem_opt

Parent 8f3b2523
@@ -51,7 +51,8 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
 cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc memory_optimize_helper.cc DEPS graph graph_helper pass)
-cc_library(inplace_op_pass SRCS inplace_op_pass DEPS memory_optimize_pass op_info)
+cc_library(graph_print_pass SRCS graph_print_pass.cc DEPS graph_helper pass)
+cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info graph_print_pass)
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
 cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
            all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
@@ -72,6 +73,7 @@ if (WITH_GPU)
 endif()
 cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
 cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
+cc_test(graph_print_pass_test SRCS graph_print_pass_test.cc DEPS graph_print_pass framework_proto graph graph_helper op_registry pass)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
@@ -96,4 +98,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
             multi_devices_graph_print_pass multi_devices_graph_check_pass
             fuse_elewise_add_act_pass multi_batch_merge_pass
             fuse_relu_depthwise_conv_pass
-            memory_optimize_pass lock_free_optimize_pass)
+            memory_optimize_pass lock_free_optimize_pass graph_print_pass)
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <memory>
+#include "paddle/fluid/framework/details/graph_print_pass.h"
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
@@ -43,8 +44,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
       : ir::PassBuilder(), strategy_(strategy) {
     if (strategy_.enable_inplace_) {
+      // before inplaced
+      // if (!strategy_.debug_graphviz_path_.empty()) {
+      //   const std::string path = strategy_.debug_graphviz_path_ +
+      //   "before_inplaced";
+      //   auto pass = AppendPass("graph_print_pass");
+      //   pass->Set<std::string>(kGraphvizPath, new std::string(path));
+      // }
       AppendPass("inplace_pass");
+      // after inplaced
+      // if (!strategy_.debug_graphviz_path_.empty()) {
+      //   const std::string path = strategy_.debug_graphviz_path_ +
+      //   "after_inplaced";
+      //   auto pass = AppendPass("graph_print_pass");
+      //   pass->Set<std::string>(details::kGraphvizPath, new
+      //   std::string(path));
+      // }
     }
     if (strategy_.enable_sequential_execution_) {
       AppendPass("sequential_execution_pass");
     }
@@ -189,6 +207,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
     } else if (pass->Type() == "memory_optimize_pass") {
+      if (graph->Has(kAllOpDescs)) {
+        graph->Erase(kAllOpDescs);
+      }
       const std::vector<OpDesc *> *all_op_descs =
           new std::vector<OpDesc *>(main_program.Block(0).AllOps());
       graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
@@ -219,6 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       if (graph->Has(kAllOpDescs)) {
        graph->Erase(kAllOpDescs);
       }
+      if (!graph->Has(kGraphviz)) {
+        graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
+      }
       graph->Set<const std::vector<OpDesc *>>(
           kAllOpDescs,
           new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
@@ -228,6 +252,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
                  "GPU, skipped.";
         continue;
       }
+    } else if (pass->Type() == "graph_print_path") {
+      if (!graph->Has(kGraphviz)) {
+        graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
+      }
     }
     graph = pass->Apply(std::move(graph));
   }
@@ -253,3 +281,4 @@ USE_PASS(all_reduce_deps_pass);
 USE_PASS(modify_op_lock_and_record_event_pass);
 USE_PASS(inplace_pass);
 USE_PASS(lock_free_optimize_pass);
+USE_PASS(graph_print_pass);
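For reference, a minimal sketch of what the commented-out wiring above would look like if it were enabled, so that the builder dumps the SSA graph before and after inplace_pass. This is not part of the commit; the path suffixes and the use of strategy_.debug_graphviz_path_ are taken directly from the commented code:

// Sketch only, assuming strategy_.debug_graphviz_path_ is filled in by the user.
if (strategy_.enable_inplace_) {
  if (!strategy_.debug_graphviz_path_.empty()) {
    const std::string path = strategy_.debug_graphviz_path_ + "before_inplaced";
    auto pass = AppendPass("graph_print_pass");
    pass->Set<std::string>(kGraphvizPath, new std::string(path));
  }
  AppendPass("inplace_pass");
  if (!strategy_.debug_graphviz_path_.empty()) {
    const std::string path = strategy_.debug_graphviz_path_ + "after_inplaced";
    auto pass = AppendPass("graph_print_pass");
    pass->Set<std::string>(kGraphvizPath, new std::string(path));
  }
}

The pass reads its output path from the kGraphvizPath attribute, which is enforced by RequirePassAttr when the pass is registered (see graph_print_pass.cc below).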
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include <string>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
class GraphvizVar : public GraphvizNode {
public:
GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]"
<< std::endl;
return sout;
}
};
class GraphvizOp : public GraphvizNode {
public:
GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
<< "\", shape=rect]" << std::endl;
PADDLE_ENFORCE(op.stream_.rdbuf()->in_avail() != 0,
"No inputs outputs. Please call AddEdge first!");
sout << op.stream_.str();
return sout;
}
template <typename Callback>
void AddEdge(const Callback& cb) {
std::string op_name = "op_" + std::to_string(id_);
for (auto var : node_->inputs) {
std::string var_name = "var_" + std::to_string(cb(var));
stream_ << var_name << "->" << op_name << std::endl;
}
for (auto var : node_->outputs) {
std::string var_name = "var_" + std::to_string(cb(var));
stream_ << op_name << "->" << var_name << std::endl;
}
}
private:
std::ostringstream stream_;
};
template <typename T, typename Container>
std::vector<T*> FilterByNodeWrapper(const Container& con) {
std::vector<T*> ret;
for (auto& node : con) {
auto i = dynamic_cast<T*>(node.get());
if (i != nullptr) ret.emplace_back(i);
}
return ret;
}
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
const ir::Graph& graph) const {
// Convert to GraphvizNode format
auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
graphviz_nodes.clear();
std::unordered_map<ir::Node*, int> vars;
int var_id = 0;
int op_id = 0;
for (auto& node : graph.Nodes()) {
if (node->IsVar()) {
graphviz_nodes.emplace(new GraphvizVar(node, var_id));
vars.emplace(std::make_pair(node, var_id++));
} else if (node->IsOp()) {
graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
} else {
PADDLE_THROW("Unknown op type");
}
}
return vars;
}
void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
std::ostream& sout) const {
auto vars = ToGraphvizNode(graph);
auto& nodes = graph.Get<GraphvizNodes>(kGraphviz);
sout << "digraph G {\n";
for (auto& var : FilterByNodeWrapper<GraphvizVar>(nodes)) {
sout << *var;
}
for (auto& op : FilterByNodeWrapper<GraphvizOp>(nodes)) {
op->AddEdge([&vars](ir::Node* var) { return vars.at(var); });
sout << *op;
}
sout << "}\n";
}
std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
printer_.reset(new SSAGraphPrinterImpl());
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good() == true, "Failed to open file.");
printer_->Print(*graph, *fout);
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
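For orientation, a minimal usage sketch (not part of the commit) of driving the registered pass by hand. The graph variable and the output file name are placeholders; the graph has to own a GraphvizNodes collection under kGraphviz before Print runs, just as the test below sets up:

// Sketch: assumes `graph` is an existing std::unique_ptr<ir::Graph>.
using paddle::framework::details::GraphvizNodes;
using paddle::framework::details::kGraphviz;
using paddle::framework::details::kGraphvizPath;

graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);  // storage the printer fills
auto pass = paddle::framework::ir::PassRegistry::Instance().Get("graph_print_pass");
// kGraphvizPath is required via RequirePassAttr; the file name is a placeholder.
pass->Set<std::string>(kGraphvizPath, new std::string("ssa_graph.dot"));
graph = pass->Apply(std::move(graph));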
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
constexpr char kGraphviz[] = "graphviz";
class GraphvizNode {
public:
GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
virtual ~GraphvizNode() = default;
protected:
ir::Node* node_;
int id_;
};
class GraphvizNode;
typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};
class SSAGraphPrinterImpl : public SSAGraphPrinter {
public:
void Print(const ir::Graph& graph, std::ostream& sout) const override;
private:
std::unordered_map<ir::Node*, int> ToGraphvizNode(
const ir::Graph& graph) const;
};
class SSAGraphPrintPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
mutable std::unique_ptr<SSAGraphPrinter> printer_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
paddle::framework::SplitOpMaker);
/*
a @ b
c
d @ e
*/
using paddle::framework::ProgramDesc;
using paddle::framework::proto::VarType;
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a", "b"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("split");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"d", "e"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"d", "e"});
op->SetOutput("Out", {"d"});
}
return prog;
}
namespace paddle {
namespace framework {
namespace details {
TEST(SSAGraphPrinter, Normal) {
auto program = FillProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
// redirect debug graph to a file.
constexpr char graph_path[] = "graph_print_pass.txt";
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
PADDLE_ENFORCE(fout->good());
printer->Print(*graph, *fout);
}
} // namespace details
} // namespace framework
} // namespace paddle
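The printer emits plain Graphviz DOT. For the three-op program built above, the graph_print_pass.txt file written by this test looks roughly like the snippet below; the numeric ids and the line order come from an unordered node set, so treat this as illustrative only:

digraph G {
var_0 [label="a"]
var_1 [label="b"]
var_2 [label="c"]
op_0 [label="sum", shape=rect]
var_0->op_0
var_1->op_0
op_0->var_2
...
}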
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SplitOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "");
AddOutput("Out", "").AsDuplicable();
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
@@ -21,6 +21,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include "paddle/fluid/framework/details/graph_print_pass.h"
 #include "paddle/fluid/framework/details/memory_optimize_pass.h"
 #include "paddle/fluid/framework/op_info.h"
@@ -76,42 +77,92 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-static inline ir::Node* GetNextInplacedOpOutput(ir::Node* var) {
+static inline std::string NodeDebugString(ir::Node* var) {
+  std::ostringstream os;
+  if (var->IsCtrlVar()) {
+    os << "kControlDepVarName"
+       << " ";
+  } else if (var->IsOp()) {
+    os << "kOperation"
+       << " " << var->Name();
+    PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name());
+  } else if (var->IsVar()) {
+    os << "kVariable"
+       << " " << var->Name();
+    PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name());
+  } else {
+    PADDLE_THROW("Unknown node type.");
+  }
+  return os.str();
+}
+
+static inline std::string OpDebugString(ir::Node* var) {
+  ir::Node* op = var;
+  if (var->IsVar()) op = var->inputs.at(0);
+  std::stringstream os;
+  os << op->Name() << " : ";
+  os << "Input ";
+  VLOG(3) << op->Name();
+  for (auto* var : op->inputs) {
+    if (var->IsVar() && !var->IsCtrlVar()) {
+      PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
+                     "unmatched desc and var");
+      // os << var << ":" << var->Name() << " ";
+      os << var->Name() << " ";
+    }
+  }
+  os << "Output ";
+  VLOG(3) << op->Name();
+  for (auto* var : op->outputs) {
+    VLOG(3) << var;
+    VLOG(3) << var->Name();
+    if (!var->IsVar()) {
+      VLOG(3) << "error";
+    }
+    // VLOG(3) << var->Var()->Name();
+    if (var->IsVar() && !var->IsCtrlVar()) {
+      PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
+                     "unmatched desc and var");
+      // os << var << ":" << var->Name() << " ";
+      os << var->Name() << " ";
+    }
+    if (var->Name() == "fc_10.tmp_0") {
+      VLOG(3) << NodeDebugString(var);
+    }
+  }
+  return os.str();
+}
+
+static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
   // if next op is inplaced, then return the output var
   // otherwise return nullptr
   PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
   ir::Node* inplaced_var = nullptr;
-  // only has one output op can be inplaced
-  if (var->outputs.size() == 1 && var->outputs[0]->IsOp()) {
-    auto* op = var->outputs[0];
-    for (auto* out_var : op->outputs) {
-      if (!out_var->IsVar() || out_var->IsCtrlVar() ||
-          out_var->Var() == nullptr)
-        continue;
-      if (out_var->Name() == var->Name()) {
-        inplaced_var = out_var;
-        break;
+  for (auto* next_op : var->outputs) {
+    for (auto* output : next_op->outputs) {
+      if (output->IsVar() && !output->IsCtrlVar() &&
+          output->Name() == var->Name()) {
+        inplaced_var = output;
       }
     }
   }
   return inplaced_var;
 }
 
-static inline ir::Node* GetPrevInplacedOpInput(ir::Node* var) {
+static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
   PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
-  ir::Node* inplaced_var = nullptr;
-  if (var->inputs.size() == 1 && var->inputs[0]->IsOp()) {
-    auto* op = var->inputs[0];
-    for (auto* in_var : op->inputs) {
-      if (!in_var->IsVar() || in_var->IsCtrlVar() || in_var->Var() == nullptr)
-        continue;
-      if (in_var->Name() == var->Name()) {
-        inplaced_var = in_var;
-        break;
-      }
-    }
-  }
-  return inplaced_var;
+  auto* prev_op = var->inputs.at(0);
+  auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
+                               [&](ir::Node* node) {
+                                 if (node->IsVar() && !node->IsCtrlVar() &&
+                                     node->Name() == var->Name()) {
+                                   return true;
+                                 } else {
+                                   return false;
+                                 }
+                               });
+  return input_it == prev_op->inputs.end() ? nullptr : *input_it;
 }
 
 template <typename Container>
@@ -166,12 +217,22 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
   view_.Build(graph.get());
   InitSSAGraphNodes();
 
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
   for (auto* op : view_.AllOps()) {
     if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
       continue;
     TryInplaceOpInputOutput(op, graph.get());
   }
   graph->ResolveHazard(var_nodes_);
+
+  constexpr char graph_path[] = "ir_graph_inplaced.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(*graph, *fout);
+
+  // for(auto* op : view_.AllOps()) {
+  //   VLOG(3) << OpDebugString(op);
+  // }
   return graph;
 }
@@ -179,7 +240,7 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
                                     const std::string& cache_var,
                                     const size_t& idx) const {
   for (size_t i = idx; i < view_.AllOps().size(); ++i) {
-    auto* op = view_.AllOps()[i];
+    ir::Node* op = view_.AllOps()[i];
     PADDLE_ENFORCE(op->IsOp() && op->Op());
     auto* op_desc = op->Op();
     op_desc->RenameInput(var, cache_var);
@@ -203,14 +264,28 @@ void InplacePass::InplaceModifyVar(const std::string& var,
   // redirect the input to the latest version of cache_var
   for (auto* node : op->inputs) {
     if (node->Name() == var) {
-      ir::Node* cache_node = var_nodes_[cache_var].back();
+      ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+      var_nodes_[cache_var].emplace_back(cache_node);
+
       // swap node to cache_node
       cache_node->outputs.insert(cache_node->outputs.end(),
                                  node->outputs.begin(), node->outputs.end());
+      PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
+      auto* prev_op = node->inputs[0];
+      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
+                   cache_node);
+      cache_node->inputs.emplace_back(prev_op);
       for (auto* next_op : node->outputs) {
         std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                      cache_node);
       }
+
+      // release unused var in graph. Because python side memory optimize
+      // may reused the var in same name, so we only clear the var node
+      // after current inplaced index.
+      graph->RemoveNode(node);
+      auto& nodes = var_nodes_.at(var);
+      nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
     }
   }
@@ -220,7 +295,6 @@ void InplacePass::InplaceModifyVar(const std::string& var,
     if (node->Name() == var) {
       ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
       var_nodes_[cache_var].emplace_back(cache_node);
-
       // swap node to cache node
       cache_node->outputs.insert(cache_node->outputs.end(),
                                  node->outputs.begin(), node->outputs.end());
@@ -230,15 +304,14 @@ void InplacePass::InplaceModifyVar(const std::string& var,
         std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                      cache_node);
       }
+
+      // release unsed var in graph
+      graph->RemoveNode(node);
+      auto& nodes = var_nodes_.at(var);
+      nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
     }
   }
-
-  // release node of unused var in graph
-  for (auto* node : var_nodes_[var]) {
-    graph->RemoveNode(node);
-  }
-  var_nodes_.at(var).clear();
 }
 
 void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
@@ -260,6 +333,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
   auto& all_ops = view_.AllOps();
   auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
   size_t idx = std::distance(all_ops.begin(), cursor);
+  VLOG(3) << op->Name() << idx;
 
   for (auto& pair : in_to_outs) {
     auto& in_var_name = pair.first;
@@ -286,6 +360,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
     }
     VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
                                out_var_name, in_var_name);
+    // VLOG(3) << "Out " << OpDebugString(op);
     InplaceModifyDesc(out_var_name, in_var_name, idx);
     InplaceModifyVar(out_var_name, in_var_name, idx, graph);
   }
@@ -319,7 +394,16 @@ ir::Node* GraphView::GetNodeByName(const std::string& name,
 }
 
 std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
-  return node->outputs;
+  // get the pending ops depends on same var node.
+  // because node also maybe a inplaced variable, so need to backtrack all the
+  // previous inplaced vars.
+  std::vector<ir::Node*> pending_ops;
+  ir::Node* p = node;
+  while (p != nullptr) {
+    pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end());
+    p = GetPrevCascadeInplacedVar(p);
+  }
+  return pending_ops;
 }
 
 void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
@@ -354,14 +438,14 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
   // get the ops with same output name
   while (out != nullptr) {
     out_var_set.emplace(out);
-    out = GetNextInplacedOpOutput(out);
+    out = GetNextCascadeInplacedVar(out);
   }
 
   // get ops with same input name
   ir::Node* in = in_var;
   while (in != nullptr) {
     in_var_set.emplace(in);
-    in = GetPrevInplacedOpInput(in);
+    in = GetPrevCascadeInplacedVar(in);
   }
   // find if there is path with control dep var connect the in_var_set and
   // out_var_set
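To make the renamed helpers easier to follow, here is a small illustration with made-up op names; it is not taken from the patch:

// Illustration only (hypothetical ops): after cascaded in-place ops the same
// variable name owns several SSA nodes,
//
//   a(v0) -> relu(inplace) -> a(v1) -> scale(inplace) -> a(v2)
//
// GetNextCascadeInplacedVar(a(v0)) returns a(v1), GetPrevCascadeInplacedVar(a(v2))
// returns a(v1), and GraphView::PendingOpsOnVar(a(v2)) now collects the consumer
// ops of a(v2), a(v1) and a(v0) by backtracking this chain, instead of returning
// only the outputs of a(v2) as before.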
@@ -18,57 +18,13 @@
 #include <iterator>
 #include "glog/logging.h"
 #include "gtest/gtest.h"
+#include "paddle/fluid/framework/details/graph_test_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 
-namespace paddle {
-namespace framework {
-
-class DummyOp : public OperatorBase {
- public:
-  DummyOp(const std::string& type, const VariableNameMap& inputs,
-          const VariableNameMap& outputs, const AttributeMap& attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const Scope& scope,
-               const platform::Place& place) const override {}
-};
-
-class SumOpMaker : public OpProtoAndCheckerMaker {
- public:
-  void Make() {
-    AddInput("X", "").AsDuplicable();
-    AddOutput("Out", "");
-    AddComment("");
-  }
-};
-
-class AssignOpMaker : public OpProtoAndCheckerMaker {
- public:
-  void Make() {
-    AddInput("X", "").AsDuplicable();
-    AddOutput("Out", "");
-    AddComment("");
-  }
-};
-
-class DummyVarTypeInference : public VarTypeInference {
- public:
-  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
-    auto& inputs = op_desc.Input("X");
-    auto type = block->Var(inputs.front())->GetType();
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(type);
-  }
-};
-
-}  // namespace framework
-}  // namespace paddle
-
 REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                   paddle::framework::SumOpMaker,
                   paddle::framework::DummyVarTypeInference);
@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() {
   return prog;
 }
 
-template <typename Container>
-inline static std::string DebugString(const Container& c) {
-  std::stringstream ss;
-  for (auto& item : c) {
-    ss << item << " ";
-  }
-  return ss.str();
-}
-
 TEST(CFGGraph, IRGraph) {
   // prepare ir graph
   auto prog = FillProgramDesc();
@@ -19,20 +19,12 @@
 #include <iosfwd>
 #include <ostream>
 #include <string>
-#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/details/graph_print_pass.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
 
-constexpr char kGraphvizPath[] = "debug_graphviz_path";
-
-class SSAGraphPrinter {
- public:
-  virtual ~SSAGraphPrinter() {}
-  virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
-};
-
 class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
  public:
   void Print(const ir::Graph& graph, std::ostream& sout) const override;
@@ -40,7 +40,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                                    seed=None,
                                    use_parallel_executor=True,
                                    use_reduce=False,
-                                   use_ir_memory_optimize=False,
+                                   use_ir_memory_optimize=True,
                                    enable_inplace=True,
                                    fuse_elewise_add_act_ops=False,
                                    fuse_relu_depthwise_conv=False,
@@ -61,64 +61,66 @@ class TestParallelExecutorBase(unittest.TestCase):
                 main.random_seed = seed

            loss = method(use_feed=feed_dict is not None)

            if optimizer:
                optimizer().minimize(loss)

            if memory_opt:
                fluid.memory_optimize(main)

+        with open("program_model.txt", "w") as f:
+            f.write(str(main))
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup)
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.allow_op_delay = allow_op_delay
        if use_fast_executor:
            exec_strategy.use_experimental_executor = True
        build_strategy = fluid.BuildStrategy()
        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
            if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
        build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
        build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
        build_strategy.memory_optimize = use_ir_memory_optimize
        build_strategy.enable_inplace = enable_inplace
        build_strategy.enable_sequential_execution = enable_sequential_execution
+        build_strategy.debug_graphviz_path = "debug_ir_graph_"

        if use_cuda and core.is_compiled_with_cuda():
            build_strategy.remove_unnecessary_lock = True
        if use_parallel_executor:
            binary = compiler.CompiledProgram(main).with_data_parallel(
                loss_name=loss.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)
        else:
            binary = compiler.CompiledProgram(main)

        if batch_size is not None:
            batch_size *= fluid.core.get_cuda_device_count(
            ) if use_cuda else int(
                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
        begin = time.time()
        first_loss, = run_executor(
            exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])

        for i in range(iter):
-            run_executor(
-                exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
+            run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[])

        last_loss, = run_executor(
            exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
        end = time.time()

        if batch_size is not None:
            print("%.4f Instance per second" % (
                (batch_size * iter + 2) / (end - begin)))

        avg_last_loss_val = np.array(last_loss).mean()
        avg_first_loss_val = np.array(first_loss).mean()
        if math.isnan(float(avg_last_loss_val)) or math.isnan(
                float(avg_first_loss_val)):
            sys.exit("got NaN loss, training failed.")

        print(first_loss, last_loss)
        # self.assertGreater(first_loss[0], last_loss[0])
        return first_loss, last_loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase
def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(3):
hidden = fluid.layers.fc(
hidden,
size=200,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestIrInplace(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
np.random.seed(5)
img = np.random.random(size=[32, 784]).astype(np.float32)
label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=True,
memory_opt=False, # inplace is conflict with memory opt
use_ir_memory_optimize=ir_memory_optimize,
enable_inplace=enable_inplace)
def test_fc_with_batchnorm(self, delta=1e-3):
loss00 = self._fc_with_batchnorm(False, False)
loss10 = self._fc_with_batchnorm(True, False)
loss01 = self._fc_with_batchnorm(False, True)
loss11 = self._fc_with_batchnorm(True, True)
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta)