Commit 2739096e authored by: D dzhwinter

compatible with Python-side mem_opt

Parent 8f3b2523
......@@ -51,7 +51,8 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc memory_optimize_helper.cc DEPS graph graph_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass DEPS memory_optimize_pass op_info)
cc_library(graph_print_pass SRCS graph_print_pass.cc DEPS graph_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info graph_print_pass)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
......@@ -72,6 +73,7 @@ if (WITH_GPU)
endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_test(graph_print_pass_test SRCS graph_print_pass_test.cc DEPS graph_print_pass framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......@@ -96,4 +98,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass)
memory_optimize_pass lock_free_optimize_pass graph_print_pass)
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
......@@ -43,8 +44,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
if (strategy_.enable_inplace_) {
// before inplaced
// if (!strategy_.debug_graphviz_path_.empty()) {
// const std::string path = strategy_.debug_graphviz_path_ +
// "before_inplaced";
// auto pass = AppendPass("graph_print_pass");
// pass->Set<std::string>(kGraphvizPath, new std::string(path));
// }
AppendPass("inplace_pass");
// after inplaced
// if (!strategy_.debug_graphviz_path_.empty()) {
// const std::string path = strategy_.debug_graphviz_path_ +
// "after_inplaced";
// auto pass = AppendPass("graph_print_pass");
// pass->Set<std::string>(details::kGraphvizPath, new
// std::string(path));
// }
}
if (strategy_.enable_sequential_execution_) {
AppendPass("sequential_execution_pass");
}
......@@ -189,6 +207,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
} else if (pass->Type() == "memory_optimize_pass") {
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
......@@ -219,6 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs);
}
if (!graph->Has(kGraphviz)) {
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
}
graph->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
......@@ -228,6 +252,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "graph_print_path") {
if (!graph->Has(kGraphviz)) {
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
}
}
graph = pass->Apply(std::move(graph));
}
......@@ -253,3 +281,4 @@ USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(graph_print_pass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include <string>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
class GraphvizVar : public GraphvizNode {
public:
GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]"
<< std::endl;
return sout;
}
};
class GraphvizOp : public GraphvizNode {
public:
GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
<< "\", shape=rect]" << std::endl;
    PADDLE_ENFORCE(op.stream_.rdbuf()->in_avail() != 0,
                   "No inputs/outputs recorded. Please call AddEdge first!");
sout << op.stream_.str();
return sout;
}
template <typename Callback>
void AddEdge(const Callback& cb) {
std::string op_name = "op_" + std::to_string(id_);
for (auto var : node_->inputs) {
std::string var_name = "var_" + std::to_string(cb(var));
stream_ << var_name << "->" << op_name << std::endl;
}
for (auto var : node_->outputs) {
std::string var_name = "var_" + std::to_string(cb(var));
stream_ << op_name << "->" << var_name << std::endl;
}
}
private:
std::ostringstream stream_;
};
template <typename T, typename Container>
std::vector<T*> FilterByNodeWrapper(const Container& con) {
std::vector<T*> ret;
for (auto& node : con) {
auto i = dynamic_cast<T*>(node.get());
if (i != nullptr) ret.emplace_back(i);
}
return ret;
}
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
const ir::Graph& graph) const {
// Convert to GraphvizNode format
auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
graphviz_nodes.clear();
std::unordered_map<ir::Node*, int> vars;
int var_id = 0;
int op_id = 0;
for (auto& node : graph.Nodes()) {
if (node->IsVar()) {
graphviz_nodes.emplace(new GraphvizVar(node, var_id));
vars.emplace(std::make_pair(node, var_id++));
} else if (node->IsOp()) {
graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
} else {
PADDLE_THROW("Unknown op type");
}
}
return vars;
}
void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
std::ostream& sout) const {
auto vars = ToGraphvizNode(graph);
auto& nodes = graph.Get<GraphvizNodes>(kGraphviz);
sout << "digraph G {\n";
for (auto& var : FilterByNodeWrapper<GraphvizVar>(nodes)) {
sout << *var;
}
for (auto& op : FilterByNodeWrapper<GraphvizOp>(nodes)) {
op->AddEdge([&vars](ir::Node* var) { return vars.at(var); });
sout << *op;
}
sout << "}\n";
}
std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
printer_.reset(new SSAGraphPrinterImpl());
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good() == true, "Failed to open file.");
printer_->Print(*graph, *fout);
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
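For reference, the snippet below is a minimal sketch of how the newly registered graph_print_pass could be driven on its own, outside of BuildStrategy. It is not part of this commit; it is an assumption assembled from the calls already visible in this diff (Graph::Set of kGraphviz, Pass::Set of kGraphvizPath, Pass::Apply) plus the standard ir::PassRegistry lookup.

// Illustrative sketch only (not code from this commit): dump an ir::Graph to a
// Graphviz dot file using the registered graph_print_pass. The kGraphviz and
// kGraphvizPath attributes are set the same way BuildStrategy::Apply does.
#include <memory>
#include <string>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Force the pass registration in this translation unit (same macro the
// build_strategy.cc change uses).
USE_PASS(graph_print_pass);

std::unique_ptr<paddle::framework::ir::Graph> DumpGraph(
    std::unique_ptr<paddle::framework::ir::Graph> graph,
    const std::string& dot_path) {
  namespace details = paddle::framework::details;
  // The printer reads its node wrappers from the kGraphviz graph attribute,
  // so it must exist before the pass runs.
  if (!graph->Has(details::kGraphviz)) {
    graph->Set<details::GraphvizNodes>(details::kGraphviz,
                                       new details::GraphvizNodes);
  }
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("graph_print_pass");
  // kGraphvizPath is a required pass attribute (see RequirePassAttr above).
  pass->Set<std::string>(details::kGraphvizPath, new std::string(dot_path));
  return pass->Apply(std::move(graph));
}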
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
constexpr char kGraphviz[] = "graphviz";
class GraphvizNode {
public:
GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
virtual ~GraphvizNode() = default;
protected:
ir::Node* node_;
int id_;
};
class GraphvizNode;
typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};
class SSAGraphPrinterImpl : public SSAGraphPrinter {
public:
void Print(const ir::Graph& graph, std::ostream& sout) const override;
private:
std::unordered_map<ir::Node*, int> ToGraphvizNode(
const ir::Graph& graph) const;
};
class SSAGraphPrintPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
mutable std::unique_ptr<SSAGraphPrinter> printer_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
paddle::framework::SplitOpMaker);
/*
a @ b
c
d @ e
*/
using paddle::framework::ProgramDesc;
using paddle::framework::proto::VarType;
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a", "b"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("split");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"d", "e"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"d", "e"});
op->SetOutput("Out", {"d"});
}
return prog;
}
namespace paddle {
namespace framework {
namespace details {
TEST(SSAGraphPrinter, Normal) {
auto program = FillProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
// redirect debug graph to a file.
constexpr char graph_path[] = "graph_print_pass.txt";
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
PADDLE_ENFORCE(fout->good());
printer->Print(*graph, *fout);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SplitOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "");
AddOutput("Out", "").AsDuplicable();
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
......@@ -21,6 +21,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -76,42 +77,92 @@ namespace paddle {
namespace framework {
namespace details {
static inline ir::Node* GetNextInplacedOpOutput(ir::Node* var) {
static inline std::string NodeDebugString(ir::Node* var) {
std::ostringstream os;
if (var->IsCtrlVar()) {
os << "kControlDepVarName"
<< " ";
} else if (var->IsOp()) {
os << "kOperation"
<< " " << var->Name();
PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name());
} else if (var->IsVar()) {
os << "kVariable"
<< " " << var->Name();
PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name());
} else {
PADDLE_THROW("Unknown node type.");
}
return os.str();
}
static inline std::string OpDebugString(ir::Node* var) {
ir::Node* op = var;
if (var->IsVar()) op = var->inputs.at(0);
std::stringstream os;
os << op->Name() << " : ";
os << "Input ";
VLOG(3) << op->Name();
for (auto* var : op->inputs) {
if (var->IsVar() && !var->IsCtrlVar()) {
PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
"unmatched desc and var");
// os << var << ":" << var->Name() << " ";
os << var->Name() << " ";
}
}
os << "Output ";
VLOG(3) << op->Name();
for (auto* var : op->outputs) {
VLOG(3) << var;
VLOG(3) << var->Name();
if (!var->IsVar()) {
VLOG(3) << "error";
}
// VLOG(3) << var->Var()->Name();
if (var->IsVar() && !var->IsCtrlVar()) {
PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(),
"unmatched desc and var");
// os << var << ":" << var->Name() << " ";
os << var->Name() << " ";
}
if (var->Name() == "fc_10.tmp_0") {
VLOG(3) << NodeDebugString(var);
}
}
return os.str();
}
static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
// if next op is inplaced, then return the output var
// otherwise return nullptr
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
ir::Node* inplaced_var = nullptr;
// only has one output op can be inplaced
if (var->outputs.size() == 1 && var->outputs[0]->IsOp()) {
auto* op = var->outputs[0];
for (auto* out_var : op->outputs) {
if (!out_var->IsVar() || out_var->IsCtrlVar() ||
out_var->Var() == nullptr)
continue;
if (out_var->Name() == var->Name()) {
inplaced_var = out_var;
break;
for (auto* next_op : var->outputs) {
for (auto* output : next_op->outputs) {
if (output->IsVar() && !output->IsCtrlVar() &&
output->Name() == var->Name()) {
inplaced_var = output;
}
}
}
return inplaced_var;
}
static inline ir::Node* GetPrevInplacedOpInput(ir::Node* var) {
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
ir::Node* inplaced_var = nullptr;
if (var->inputs.size() == 1 && var->inputs[0]->IsOp()) {
auto* op = var->inputs[0];
for (auto* in_var : op->inputs) {
if (!in_var->IsVar() || in_var->IsCtrlVar() || in_var->Var() == nullptr)
continue;
if (in_var->Name() == var->Name()) {
inplaced_var = in_var;
break;
}
}
}
return inplaced_var;
auto* prev_op = var->inputs.at(0);
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
[&](ir::Node* node) {
if (node->IsVar() && !node->IsCtrlVar() &&
node->Name() == var->Name()) {
return true;
} else {
return false;
}
});
return input_it == prev_op->inputs.end() ? nullptr : *input_it;
}
template <typename Container>
......@@ -166,12 +217,22 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
view_.Build(graph.get());
InitSSAGraphNodes();
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
for (auto* op : view_.AllOps()) {
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
TryInplaceOpInputOutput(op, graph.get());
}
graph->ResolveHazard(var_nodes_);
constexpr char graph_path[] = "ir_graph_inplaced.txt";
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
PADDLE_ENFORCE(fout->good());
printer->Print(*graph, *fout);
// for(auto* op : view_.AllOps()) {
// VLOG(3) << OpDebugString(op);
// }
return graph;
}
......@@ -179,7 +240,7 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
const std::string& cache_var,
const size_t& idx) const {
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
auto* op = view_.AllOps()[i];
ir::Node* op = view_.AllOps()[i];
PADDLE_ENFORCE(op->IsOp() && op->Op());
auto* op_desc = op->Op();
op_desc->RenameInput(var, cache_var);
......@@ -203,14 +264,28 @@ void InplacePass::InplaceModifyVar(const std::string& var,
// redirect the input to the latest version of cache_var
for (auto* node : op->inputs) {
if (node->Name() == var) {
ir::Node* cache_node = var_nodes_[cache_var].back();
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
var_nodes_[cache_var].emplace_back(cache_node);
// swap node to cache_node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
cache_node);
cache_node->inputs.emplace_back(prev_op);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
        // Release the unused var node from the graph. Because the Python-side
        // memory optimization may reuse a var under the same name, only the
        // var nodes after the current inplaced index are cleared.
graph->RemoveNode(node);
auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
}
}
......@@ -220,7 +295,6 @@ void InplacePass::InplaceModifyVar(const std::string& var,
if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
var_nodes_[cache_var].emplace_back(cache_node);
// swap node to cache node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
......@@ -230,15 +304,14 @@ void InplacePass::InplaceModifyVar(const std::string& var,
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
// release unused var in graph
graph->RemoveNode(node);
auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
}
}
}
// release node of unused var in graph
for (auto* node : var_nodes_[var]) {
graph->RemoveNode(node);
}
var_nodes_.at(var).clear();
}
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
......@@ -260,6 +333,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
auto& all_ops = view_.AllOps();
auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
size_t idx = std::distance(all_ops.begin(), cursor);
VLOG(3) << op->Name() << " " << idx;
for (auto& pair : in_to_outs) {
auto& in_var_name = pair.first;
......@@ -286,6 +360,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
}
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
out_var_name, in_var_name);
// VLOG(3) << "Out " << OpDebugString(op);
InplaceModifyDesc(out_var_name, in_var_name, idx);
InplaceModifyVar(out_var_name, in_var_name, idx, graph);
}
......@@ -319,7 +394,16 @@ ir::Node* GraphView::GetNodeByName(const std::string& name,
}
std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
return node->outputs;
  // Get the pending ops that depend on the same var node.
  // Because the node itself may be an inplaced variable, backtrack through
  // all of the previous inplaced vars as well.
std::vector<ir::Node*> pending_ops;
ir::Node* p = node;
while (p != nullptr) {
pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end());
p = GetPrevCascadeInplacedVar(p);
}
return pending_ops;
}
void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
......@@ -354,14 +438,14 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
// get the ops with same output name
while (out != nullptr) {
out_var_set.emplace(out);
out = GetNextInplacedOpOutput(out);
out = GetNextCascadeInplacedVar(out);
}
// get ops with same input name
ir::Node* in = in_var;
while (in != nullptr) {
in_var_set.emplace(in);
in = GetPrevInplacedOpInput(in);
in = GetPrevCascadeInplacedVar(in);
}
// find if there is path with control dep var connect the in_var_set and
// out_var_set
......
......@@ -18,57 +18,13 @@
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
......@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() {
return prog;
}
template <typename Container>
inline static std::string DebugString(const Container& c) {
std::stringstream ss;
for (auto& item : c) {
ss << item << " ";
}
return ss.str();
}
TEST(CFGGraph, IRGraph) {
// prepare ir graph
auto prog = FillProgramDesc();
......
......@@ -19,20 +19,12 @@
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/graph_print_pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public:
void Print(const ir::Graph& graph, std::ostream& sout) const override;
......
......@@ -40,7 +40,7 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None,
use_parallel_executor=True,
use_reduce=False,
use_ir_memory_optimize=False,
use_ir_memory_optimize=True,
enable_inplace=True,
fuse_elewise_add_act_ops=False,
fuse_relu_depthwise_conv=False,
......@@ -61,64 +61,66 @@ class TestParallelExecutorBase(unittest.TestCase):
main.random_seed = seed
loss = method(use_feed=feed_dict is not None)
if optimizer:
optimizer().minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
binary = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
else:
binary = compiler.CompiledProgram(main)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time()
first_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
last_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
if batch_size is not None:
print("%.4f Instance per second" % (
(batch_size * iter + 2) / (end - begin)))
avg_last_loss_val = np.array(last_loss).mean()
avg_first_loss_val = np.array(first_loss).mean()
if math.isnan(float(avg_last_loss_val)) or math.isnan(
float(avg_first_loss_val)):
sys.exit("got NaN loss, training failed.")
print(first_loss, last_loss)
# self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
with open("program_model.txt", "w") as f:
f.write(str(main))
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution
build_strategy.debug_graphviz_path = "debug_ir_graph_"
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
binary = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
else:
binary = compiler.CompiledProgram(main)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time()
first_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
last_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
if batch_size is not None:
print("%.4f Instance per second" % (
(batch_size * iter + 2) / (end - begin)))
avg_last_loss_val = np.array(last_loss).mean()
avg_first_loss_val = np.array(first_loss).mean()
if math.isnan(float(avg_last_loss_val)) or math.isnan(
float(avg_first_loss_val)):
sys.exit("got NaN loss, training failed.")
print(first_loss, last_loss)
# self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase
def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(3):
hidden = fluid.layers.fc(
hidden,
size=200,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestIrInplace(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
np.random.seed(5)
img = np.random.random(size=[32, 784]).astype(np.float32)
label = np.ones(shape=[32, 1], dtype='int64')
        _, last_loss = self.check_network_convergence(
            fc_with_batchnorm,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=True,
            memory_opt=False,  # inplace conflicts with memory_opt
            use_ir_memory_optimize=ir_memory_optimize,
            enable_inplace=enable_inplace)
        return np.array(last_loss).mean()
def test_fc_with_batchnorm(self, delta=1e-3):
loss00 = self._fc_with_batchnorm(False, False)
loss10 = self._fc_with_batchnorm(True, False)
loss01 = self._fc_with_batchnorm(False, True)
loss11 = self._fc_with_batchnorm(True, True)
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta)