未验证 提交 8046e33d 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add some passes which can be applied to Program (#34730)

* add inplace passes and tests

* update

* fix use_cuda undefined
fix compile error of op compat

* add more ut

* fix CPU CI error

* check adam unique

* fix mac/windows ci, improve coverage

* fix ci error

* follow weihang's comment

* fix BlockDesc::MoveFrom

* follow qiuliang's comment

* update

* follow huihuang's comments
上级 5de576b0
......@@ -238,5 +238,41 @@ BlockDesc *BlockDesc::ForwardBlock() const {
return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
}
void BlockDesc::MoveFrom(BlockDesc *block) {
PADDLE_ENFORCE_NOT_NULL(
block, platform::errors::InvalidArgument("Block must be provided."));
if (this == block) {
return;
}
for (auto &pair : block->vars_) {
const auto &name = pair.first;
auto &var_ptr = pair.second;
auto &old_var_ptr = vars_[name];
if (old_var_ptr == nullptr) {
VLOG(10) << "Create new variable " << var_ptr->Name();
old_var_ptr = std::move(var_ptr);
} else {
// NOTE(zjl): cannot release old_var_ptr, because Python
// Variable holds the reference of the C++ VarDesc object.
// If the C++ VarDesc object is destructed, any call to the
// methods of Python Variable may raise segmentation fault.
VLOG(10) << "Update old variable " << var_ptr->Name();
*old_var_ptr = *var_ptr;
}
}
ops_.clear();
for (const auto &src_op : block->ops_) {
AppendOp()->CopyFrom(*src_op);
}
need_update_ = true;
Flush();
block->ops_.clear();
block->vars_.clear();
block->need_update_ = true;
block->Flush();
}
} // namespace framework
} // namespace paddle
......@@ -111,6 +111,8 @@ class BlockDesc {
ProgramDesc *Program() const { return this->prog_; }
void MoveFrom(BlockDesc *block);
private:
ProgramDesc *prog_; // not_own
proto::BlockDesc *desc_; // not_own
......
......@@ -180,6 +180,11 @@ struct BuildStrategy {
bool IsFinalized() const { return is_finalized_; }
void ClearFinalized() {
pass_builder_ = nullptr;
is_finalized_ = false;
}
bool IsMultiDevPass(const std::string &pass_name) const;
// Apply the passes built by the pass_builder_. The passes will be
......
......@@ -50,7 +50,7 @@ if (WITH_TESTING)
endif(WITH_TESTING)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api)
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api pass)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
......
......@@ -10,7 +10,7 @@ cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_h
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass)
......
......@@ -15,6 +15,7 @@
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -30,6 +31,9 @@ class BufferSharedInplaceOpPass : public MemoryReusePass {
std::string ReuseType() const override { return "inplace"; }
void Run(Graph *graph) const override;
void ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const override;
};
void BufferSharedInplaceOpPass::Run(Graph *graph) const {
......@@ -149,6 +153,141 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
}
}
static std::string GetFirstVarName(const OpDesc &op, const std::string &slot,
bool is_input) {
const auto &name_map = is_input ? op.Inputs() : op.Outputs();
auto iter = name_map.find(slot);
if (iter != name_map.end() && !iter->second.empty()) {
return iter->second[0];
}
return kEmptyVarName;
}
static std::vector<std::vector<std::pair<std::string, std::string>>>
GetInplaceVars(const BlockDesc &block, bool use_cuda,
const std::vector<std::string> &skip_vars) {
PADDLE_ENFORCE_EQ(block.ID(), 0, platform::errors::Unimplemented(
"Inplace can only perform in block 0."));
// only take block 0 gc_vars
const auto op_gc_vars =
GetEagerDeletionCleanVars(*block.Program(), skip_vars)[0];
const auto all_ops = block.AllOps();
PADDLE_ENFORCE_EQ(op_gc_vars.size(), all_ops.size(),
platform::errors::PermissionDenied(
"GC analysis error: op number not match."));
size_t n = all_ops.size();
std::unordered_set<std::string> visited_vars;
std::unordered_set<std::string> reused_in_vars(skip_vars.begin(),
skip_vars.end());
std::unordered_set<std::string> reused_out_vars(skip_vars.begin(),
skip_vars.end());
for (const auto *op : all_ops) {
if (op->Type() == "share_buffer" || op->Type() == "share_data") {
const auto &inputs = op->Input("X");
const auto &outputs = op->Output("Out");
reused_in_vars.insert(inputs.begin(), inputs.end());
reused_out_vars.insert(outputs.begin(), outputs.end());
}
}
std::vector<std::vector<std::pair<std::string, std::string>>> result(n);
for (size_t i = 0; i < n; ++i) {
const auto &op = *all_ops[i];
const auto &gc_vars = op_gc_vars[i];
const auto inputs = op.InputArgumentNames();
const auto outputs = op.OutputArgumentNames();
visited_vars.insert(inputs.begin(), inputs.end());
auto &infer_inplace = OpInfoMap::Instance().Get(op.Type()).infer_inplace_;
if (gc_vars.empty() || !infer_inplace) {
visited_vars.insert(outputs.begin(), outputs.end());
continue;
}
const auto var_pair = infer_inplace(use_cuda);
std::unordered_multiset<std::string> input_set(inputs.begin(),
inputs.end());
std::unordered_multiset<std::string> output_set(outputs.begin(),
outputs.end());
std::unordered_set<std::string> valid_vars;
for (const auto &var : gc_vars) {
if (var != kEmptyVarName && input_set.count(var) == 1 &&
output_set.count(var) == 0 &&
block.FindVar(var)->GetType() == proto::VarType::LOD_TENSOR) {
valid_vars.insert(var);
}
}
if (valid_vars.empty()) {
visited_vars.insert(outputs.begin(), outputs.end());
continue;
}
for (const auto &pair : var_pair) {
const auto &input_slot = pair.first;
const auto &output_slot = pair.second;
auto input_var = GetFirstVarName(op, input_slot, /*is_input=*/true);
if (input_var == kEmptyVarName || valid_vars.count(input_var) == 0) {
continue;
}
auto output_var = GetFirstVarName(op, output_slot, /*is_input=*/false);
if (output_var == kEmptyVarName || visited_vars.count(output_var) > 0) {
continue;
}
auto output_var_desc = block.FindVar(output_var);
if (output_var_desc == nullptr || output_var_desc->Persistable() ||
output_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
continue;
}
if (reused_in_vars.count(input_var) > 0 ||
reused_out_vars.count(output_var) > 0) {
continue;
}
// input_var -> output_var is reusable
VLOG(10) << "inplace occurs at op " << i << " " << op.Type() << ": "
<< input_var << " -> " << output_var;
result[i].emplace_back(input_var, output_var);
reused_in_vars.insert(input_var);
reused_out_vars.insert(output_var);
}
visited_vars.insert(outputs.begin(), outputs.end());
std::sort(result[i].begin(), result[i].end());
}
return result;
}
void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
bool use_cuda = Get<bool>(kUseCuda);
auto skip_vars = Get<std::vector<std::string>>("mem_opt_skip_vars");
auto *block = main_program->MutableBlock(0);
auto inplace_vars = GetInplaceVars(*block, use_cuda, skip_vars);
PADDLE_ENFORCE_EQ(inplace_vars.size(), block->OpSize(),
platform::errors::PermissionDenied(
"Inplace analysis error: op number not match."));
int64_t n = static_cast<int64_t>(inplace_vars.size());
for (int64_t i = n - 1; i >= 0; --i) {
if (inplace_vars[i].empty()) continue;
auto *op = block->InsertOp(i);
std::vector<std::string> inputs, outputs;
inputs.reserve(inplace_vars[i].size());
outputs.reserve(inplace_vars[i].size());
for (const auto &pair : inplace_vars[i]) {
inputs.push_back(pair.first);
outputs.push_back(pair.second);
}
op->SetType("share_buffer");
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetOutput("XOut", inputs); // add necessary dependency
op->SetAttr("share_dims", std::vector<bool>(inputs.size(), false));
}
block->Flush();
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
......@@ -40,6 +41,9 @@ class InplaceAddToOpPass : public MemoryReusePass {
void Run(Graph *graph) const override;
void ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const override;
private:
// 1. Add last living op of in_var, add any last living op of out_var
// 2. Set reference count of in_var to be 2
......@@ -216,6 +220,264 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
}
}
static bool IsValidConv2DGradDataGradNode(const Node &node) {
if (node.inputs.empty()) return false;
auto *generated_op = node.inputs[0];
auto *op_desc = generated_op->Op();
if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
return false;
}
const auto &outputs = op_desc->Outputs();
auto iter = outputs.find(GradVarName("Input"));
return iter != outputs.end() && !iter->second.empty() &&
iter->second[0] == node.Name() &&
!op_desc->GetAttrIfExists<bool>("use_addto");
}
static bool IsDownstreamNode(const Node &upstream, const Node &downstream) {
std::queue<const Node *> q;
std::unordered_set<const Node *> visited;
q.push(&upstream);
visited.insert(&upstream);
while (!q.empty()) {
const auto *cur = q.front();
q.pop();
if (cur == &downstream) {
return true;
}
for (const auto *out : cur->outputs) {
if (visited.count(out) == 0) {
visited.insert(out);
q.push(out);
}
}
}
return false;
}
static void BuildInplaceAddToGraph(Node *in_var_0, Node *in_var_1,
Node *out_var, Graph *graph) {
auto *grad_add_op = out_var->inputs[0];
// Cut the connection between in_var_0 and grad_add_op
in_var_0->outputs.erase(std::remove(in_var_0->outputs.begin(),
in_var_0->outputs.end(), grad_add_op),
in_var_0->outputs.end());
grad_add_op->inputs.erase(std::remove(grad_add_op->inputs.begin(),
grad_add_op->inputs.end(), in_var_0),
grad_add_op->inputs.end());
// Replace grad_add_op with share_buffer op
auto *grad_add_op_desc = grad_add_op->Op();
grad_add_op_desc->SetType("share_buffer");
grad_add_op_desc->SetInput("X", {in_var_1->Name()});
grad_add_op_desc->SetOutput("Out", {out_var->Name()});
grad_add_op_desc->SetOutput("XOut", {in_var_1->Name()});
grad_add_op_desc->SetAttr("share_dims", std::vector<bool>(1, true));
// Add share_buffer op between in_var_0 and in_var_1
OpDesc share_buffer_op;
share_buffer_op.SetType("share_buffer");
share_buffer_op.SetInput("X", {in_var_0->Name()});
share_buffer_op.SetOutput("Out", {in_var_1->Name()});
share_buffer_op.SetOutput("XOut", {in_var_0->Name()});
share_buffer_op.SetAttr("share_dims", std::vector<bool>(1, false));
auto *new_share_buffer_op = graph->CreateOpNode(&share_buffer_op);
new_share_buffer_op->inputs.push_back(in_var_0);
in_var_0->outputs.push_back(new_share_buffer_op);
new_share_buffer_op->outputs.push_back(in_var_1);
in_var_1->inputs.push_back(new_share_buffer_op);
auto *dep_var = graph->CreateControlDepVar();
new_share_buffer_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(new_share_buffer_op);
auto in_var_1_gen_op = in_var_1->inputs[0];
in_var_1_gen_op->inputs.push_back(dep_var);
dep_var->outputs.push_back(in_var_1_gen_op);
in_var_1_gen_op->Op()->SetAttr("use_addto", true);
}
static std::unordered_map<std::string, std::vector<Node *>>
GetAllVersionVarsMap(const Graph &graph) {
const auto &nodes = graph.Nodes();
std::unordered_map<Node *, size_t> deps;
std::vector<Node *> sorted_nodes;
sorted_nodes.reserve(nodes.size());
std::queue<Node *> q;
for (auto *node : nodes) {
size_t in_degree = node->inputs.size();
if (in_degree == 0) {
q.push(node);
sorted_nodes.push_back(node);
} else {
deps[node] = node->inputs.size();
}
}
while (!q.empty()) {
auto *cur = q.front();
q.pop();
for (auto *node : cur->outputs) {
if (--deps.at(node) == 0) {
sorted_nodes.push_back(node);
q.push(node);
}
}
}
PADDLE_ENFORCE_EQ(
sorted_nodes.size(), nodes.size(),
platform::errors::PermissionDenied("Wrong toplogical sort algorithm."));
std::unordered_map<std::string, std::vector<Node *>> result;
for (auto *node : sorted_nodes) {
if (node->IsVar() && !node->IsCtrlVar()) {
result[node->Name()].push_back(node);
}
}
return result;
}
void InplaceAddToOpPass::ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
if (!Get<bool>(kUseCuda)) return;
Graph graph(*main_program);
auto all_ver_vars = GetAllVersionVarsMap(graph);
const auto all_nodes = graph.Nodes(); // Deep copy
std::unordered_set<std::string> reused_in_vars;
std::unordered_set<std::string> reused_out_vars;
for (auto *node : all_nodes) {
if (!node->IsOp() || node->Op() == nullptr ||
node->Op()->Type() != "grad_add") {
continue;
}
VLOG(10) << "Found grad_add op";
// Step 1: find input vars first
std::vector<Node *> input_vars;
input_vars.reserve(2);
for (auto *in : node->inputs) {
if (in->IsCtrlVar() || in->Name() == kEmptyVarName) {
continue;
}
PADDLE_ENFORCE_LT(input_vars.size(), 2,
platform::errors::InvalidArgument(
"The size of inputs of grad_add should be 2."));
input_vars.push_back(in);
}
if (input_vars.size() != 2) { // may have kEmptyVarName
continue;
}
bool is_first_var_valid = IsValidConv2DGradDataGradNode(*input_vars[0]);
bool is_second_var_valid = IsValidConv2DGradDataGradNode(*input_vars[1]);
if (!is_first_var_valid && !is_second_var_valid) {
continue;
}
VLOG(10) << "validation " << is_first_var_valid << " "
<< is_second_var_valid;
// make sure that input_vars[1] is always the Input@GRAD of conv2d_grad op
if (is_first_var_valid) {
std::swap(input_vars[0], input_vars[1]);
}
// Step 2: find the unique output var
Node *output_var = nullptr;
std::string output_var_name = node->Op()->Output("Out")[0];
PADDLE_ENFORCE_NE(output_var_name, kEmptyVarName,
platform::errors::InvalidArgument(
"Output of grad_add should be provided."));
for (auto *out : node->outputs) {
if (output_var_name == out->Name()) {
output_var = out;
break;
}
}
PADDLE_ENFORCE_NOT_NULL(output_var,
platform::errors::InvalidArgument(
"Output of grad_add should be provided."));
VLOG(10) << "Check inplace chain: " << input_vars[0]->Name() << " -> "
<< input_vars[1]->Name() << " -> " << output_var->Name();
// Step 3: check whether input_vars[0]->generated_op is not the downstream
// op of input_vars[0]->generated_op. If yes, circle would occur.
if (!input_vars[0]->inputs.empty() && !input_vars[1]->inputs.empty()) {
auto *gen_op_0 = input_vars[0]->inputs[0];
auto *gen_op_1 = input_vars[1]->inputs[0];
if (IsDownstreamNode(*gen_op_1, *gen_op_0)) {
VLOG(10) << "Downstream node detected, cannot inplace addto";
continue;
}
}
// Step 4: name not the same
if (input_vars[0]->Name() == input_vars[1]->Name() ||
input_vars[0]->Name() == output_var->Name() ||
input_vars[1]->Name() == output_var->Name()) {
continue;
}
// Step 5: check var version. The inplace var chain is: input_vars[0] ->
// input_vars[1] -> output_var
// Therefore, input_vars[0] must be last version, input_vars[1] must be 1st
// version and last version, and output_var must be the 1st version.
auto iter = all_ver_vars.find(input_vars[0]->Name());
PADDLE_ENFORCE_EQ(iter != all_ver_vars.end(), true,
platform::errors::InvalidArgument(
"Variable %s not found.", input_vars[0]->Name()));
if (iter->second[iter->second.size() - 1] != input_vars[0]) continue;
iter = all_ver_vars.find(input_vars[1]->Name());
if (iter->second.size() != 1) continue;
PADDLE_ENFORCE_EQ(iter->second[0], input_vars[1],
platform::errors::InvalidArgument(
"Variable %s not found.", input_vars[1]->Name()));
iter = all_ver_vars.find(output_var->Name());
if (iter->second[0] != output_var) continue;
// Step 6: input_vars[0] and input_vars[1] should only have one output op!
// This output op must be grad_add op.
if (input_vars[0]->outputs.size() != 1 ||
input_vars[1]->outputs.size() != 1) {
continue;
}
// Step 7: check whether the var has been reused
if (reused_in_vars.count(input_vars[0]->Name()) > 0 ||
reused_in_vars.count(input_vars[1]->Name()) > 0 ||
reused_out_vars.count(input_vars[1]->Name()) > 0 ||
reused_out_vars.count(output_var->Name()) > 0) {
continue;
}
VLOG(10) << "inplace occurs at " << input_vars[0]->Name() << " -> "
<< input_vars[1]->Name() << " -> " << output_var->Name();
// Step 8: inplace addto can be performed now!
BuildInplaceAddToGraph(input_vars[0], input_vars[1], output_var, &graph);
reused_in_vars.insert(input_vars[0]->Name());
reused_in_vars.insert(input_vars[1]->Name());
reused_out_vars.insert(input_vars[1]->Name());
reused_out_vars.insert(output_var->Name());
}
// Convert Graph to main_program
ProgramDesc tmp;
GraphToProgram(graph, &tmp);
main_program->CopyFrom(*tmp.Proto());
main_program->Flush();
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
......@@ -31,17 +32,18 @@ namespace paddle {
namespace framework {
namespace ir {
Graph* Pass::Apply(Graph* graph) const {
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
CheckPrevPass();
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
for (const std::string& attr : required_pass_attrs_) {
for (const std::string &attr : required_pass_attrs_) {
PADDLE_ENFORCE_NE(
attrs_.find(attr), attrs_.end(),
platform::errors::InvalidArgument(
"Required atrribute %s for pass < %s > is not set.", attr, Type()));
}
for (const std::string& attr : required_graph_attrs_) {
for (const std::string &attr : required_graph_attrs_) {
PADDLE_ENFORCE_EQ(graph->Has(attr), true,
platform::errors::InvalidArgument(
"Required atrribute %s for graph is not set.", attr));
......@@ -66,30 +68,103 @@ Graph* Pass::Apply(Graph* graph) const {
// Passes can change params, tensors, so caching need to be discarded
ClearMKLDNNCache(paddle::platform::CPUPlace());
#endif
VLOG(10) << "finish to apply pass " << Type() << " to graph";
return graph;
}
void Pass::Apply(ProgramDesc* main_program,
ProgramDesc* startup_program) const {
void Pass::Apply(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
VLOG(10) << "apply pass " << Type() << " to program";
PADDLE_ENFORCE_NOT_NULL(main_program, platform::errors::InvalidArgument(
"main program must be provided"));
PADDLE_ENFORCE_NOT_NULL(
startup_program,
platform::errors::InvalidArgument("startup program must be provided"));
ApplyImpl(main_program, startup_program);
VLOG(10) << "finish to apply pass " << Type() << " to program";
}
template <typename Container, typename Visitor>
static void VisitAllElements(Container &&container, Visitor &&visitor,
bool reverse) {
if (reverse) {
std::for_each(container.rbegin(), container.rend(), visitor);
} else {
std::for_each(container.begin(), container.end(), visitor);
}
}
void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
bool append) {
PADDLE_ENFORCE_NOT_NULL(
dst, platform::errors::InvalidArgument("Dst program must be provided."));
bool reverse = !append;
auto create_var_visitor = [dst](const ProgramDesc &src) {
PADDLE_ENFORCE_EQ(src.Size(), 1, platform::errors::Unimplemented(
"MergePrograms can only support to "
"merge program with only one block."));
const auto &src_block = src.Block(0);
auto *dst_block = dst->MutableBlock(0);
for (const auto *src_new_var : src_block.AllVars()) {
if (dst_block->FindVar(src_new_var->Name())) continue;
auto *dst_new_var = dst_block->Var(src_new_var->Name());
*dst_new_var = *src_new_var;
VLOG(10) << "Create new variable " << dst_new_var->Name();
}
};
VisitAllElements(srcs, create_var_visitor, reverse);
auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
auto ops = src.Block(0).AllOps();
auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) {
auto *dst_block = dst->MutableBlock(0);
auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp();
op->CopyFrom(*src_op);
VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type();
// FIXME(zjl): some passes does not add VarDesc to program,
// we should fix this bug later...
for (const auto &in_var_name : op->InputArgumentNames()) {
dst_block->Var(in_var_name);
}
for (const auto &out_var_name : op->OutputArgumentNames()) {
dst_block->Var(out_var_name);
}
};
VisitAllElements(ops, copy_op_visitor, reverse);
};
VisitAllElements(srcs, create_op_visitor, reverse);
}
void Pass::ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
Graph graph(*main_program);
Apply(&graph);
// TODO(zjl): support details::kStartupProgramDescs and details::kProgramDescs
ProgramDesc new_main_program;
GraphToProgram(graph, &new_main_program);
main_program->CopyFrom(*new_main_program.Proto());
if (graph.Has(details::kStartupProgramDescs)) {
const auto &startups =
graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);
VLOG(10) << "Merge startup programs";
MergePrograms(startup_program, startups, /*append=*/true);
}
if (graph.Has(details::kProgramDescs)) {
const auto &mains =
graph.Get<details::ProgramDescs>(details::kProgramDescs);
VLOG(10) << "Merge main programs";
MergePrograms(main_program, mains, /*append=*/false);
}
startup_program->Flush();
main_program->Flush();
}
PassRegistry& PassRegistry::Instance() {
PassRegistry &PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
}
......
......@@ -148,6 +148,12 @@ class Pass {
"The virtual pass called is not implemented."));
}
virtual void ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const;
static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
bool append);
// Some Pass must be placed before this Pass, and some
// Pass must be placed after this Pass.
virtual void CheckPrevPass() const {}
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/share_buffer_op.h"
namespace paddle {
namespace operators {
class ShareBufferOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// dtype is not important
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
};
class ShareBufferOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensors of share buffer op")
.AsDuplicable();
AddOutput("Out", "(Tensor), The output tensors of share buffer op")
.AsDuplicable();
AddOutput("XOut",
"(Tensor), The output tensors which are the same as X. It is "
"used to build the graph dependency")
.AsDuplicable();
AddAttr<std::vector<bool>>("share_dims", "Whether to share dims")
.SetDefault(std::vector<bool>());
AddComment(
R"DOC(Operator used to perform inplace memory reuse. It should be not exposed to Python APIs.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(share_buffer, ops::ShareBufferOp, ops::ShareBufferOpMaker);
// dtype is not important
REGISTER_OP_CPU_KERNEL(share_buffer, ops::ShareBufferOpKernel<float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/share_buffer_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(share_buffer, ops::ShareBufferOpKernel<float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
template <typename T>
class ShareBufferOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto inputs = ctx.MultiInput<framework::Tensor>("X");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
size_t n = inputs.size();
PADDLE_ENFORCE_EQ(n, outputs.size(), platform::errors::PermissionDenied(
"Variable number not match."));
const auto &share_dims = ctx.Attr<std::vector<bool>>("share_dims");
if (!share_dims.empty()) {
PADDLE_ENFORCE_EQ(
n, share_dims.size(),
platform::errors::PermissionDenied(
"Attribute share_dims number not match input variable number."));
}
const std::vector<std::string> *input_args = nullptr,
*output_args = nullptr;
if (VLOG_IS_ON(10)) {
input_args = &ctx.GetOp().Inputs("X");
output_args = &ctx.GetOp().Outputs("Out");
}
for (size_t i = 0; i < n; ++i) {
if (inputs[i] == nullptr || outputs[i] == nullptr) {
continue;
}
outputs[i]->ShareBufferWith(*inputs[i]);
VLOG(10) << "Share tensor buffer " << (*input_args)[i] << " -> "
<< (*output_args)[i];
if (!share_dims.empty() && share_dims[i]) {
outputs[i]->Resize(inputs[i]->dims());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -301,6 +301,7 @@ void BindPass(py::module *m) {
// pass_attr_types to indicate the type of "nranks" explicitly,
// i.e. pass_attr_types = {"nranks": "size_t"} means that the type of
// "nranks" is size_t in C++.
REGISTER_PASS_ATTR_GETTER_SETTER("bool", bool);
REGISTER_PASS_ATTR_GETTER_SETTER("int", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("long", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("size_t", size_t);
......@@ -309,6 +310,7 @@ void BindPass(py::module *m) {
REGISTER_PASS_ATTR_GETTER_SETTER("float", double);
REGISTER_PASS_ATTR_GETTER_SETTER("bytes", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("list[str]", std::vector<std::string>);
m->def(
"apply_pass",
......
......@@ -156,7 +156,8 @@ void BindBlockDesc(pybind11::module *m) {
pybind11::return_value_policy::reference)
.def("op_size", &pd::BlockDesc::OpSize)
.def("op", &pd::BlockDesc::Op, pybind11::return_value_policy::reference)
.def("serialize_to_string", SerializeMessage<pd::BlockDesc>);
.def("serialize_to_string", SerializeMessage<pd::BlockDesc>)
.def("_move_from", &pd::BlockDesc::MoveFrom);
}
void BindVarDsec(pybind11::module *m) {
......
......@@ -2553,6 +2553,7 @@ All parameter, weight, gradient are variables in Paddle.
.value("Customized", BuildStrategy::GradientScaleStrategy::kCustomized);
build_strategy.def(py::init())
.def("_clear_finalized", &BuildStrategy::ClearFinalized)
.def_property(
"reduce_strategy",
[](const BuildStrategy &self) { return self.reduce_; },
......@@ -3074,6 +3075,12 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool fix_op_run_order) {
self.fix_op_run_order_ = fix_op_run_order;
})
.def("_copy",
[](const BuildStrategy &self) {
auto new_bs = self;
new_bs.ClearFinalized();
return new_bs;
})
.def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true);
......
......@@ -3353,6 +3353,12 @@ class Block(object):
return ret_var
# NOTE(zjl): you should be careful that after you call this method,
# some Python Variable and all Python Operators should not be used
# again. Because all Python Variables and all Python Operators are
# re-constructed inside this method. The underlying VarDesc(OpDesc)
# of some old Python Variables(all old Python Operators) may have
# been destructed.
def _apply_pass(main_program,
startup_program,
pass_name,
......@@ -4286,6 +4292,14 @@ class Program(object):
self._graph = None
def _find_var_class_kwargs(self, new_desc):
# NOTE: not all variables support shape/dtype/lod_level methods.
# For example: RAW, STEP_SCOPES, etc.
def get_var_desc_attr_or_none(var_desc, attr_name, allowed_types):
if var_desc.type() in allowed_types:
return getattr(var_desc, attr_name)()
else:
return None
old_desc = self.desc
all_new_vars = []
block_num = new_desc.num_blocks()
......@@ -4302,9 +4316,21 @@ class Program(object):
kwargs = {
'type': new_var_desc.type(),
'name': new_var_desc.name(),
'shape': new_var_desc.shape(),
'dtype': new_var_desc.dtype(),
'lod_level': new_var_desc.lod_level(),
'shape': get_var_desc_attr_or_none(new_var_desc, "shape", [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]),
'dtype': get_var_desc_attr_or_none(new_var_desc, "dtype", [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]),
'lod_level':
get_var_desc_attr_or_none(new_var_desc, "lod_level", [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]),
'error_clip': old_var.error_clip
if old_var is not None else None,
'stop_gradient': old_var.stop_gradient
......@@ -4343,14 +4369,20 @@ class Program(object):
all_new_vars = self._find_var_class_kwargs(desc)
block_num = desc.num_blocks()
assert block_num == len(all_new_vars)
assert block_num == self.desc.num_blocks()
# clear old blocks and desc
self.blocks = []
self.desc = None
for idx in range(block_num):
block = self.blocks[idx]
block.vars.clear()
block.ops.clear()
for idx in range(block_num):
block_desc = self.blocks[idx].desc
new_block_desc = desc.block(idx)
block_desc._move_from(new_block_desc)
# create new blocks and set desc
self.desc = desc
self.blocks = [Block(self, idx) for idx in range(block_num)]
del desc
# add new vars first
for idx in range(block_num):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from .framework import _apply_pass
def get_data_vars(program):
data_vars = []
for var_name, var in program.global_block().vars.items():
if var.is_data:
data_vars.append(var_name)
return data_vars
def apply_build_strategy(main_program, startup_program, build_strategy,
pass_attrs):
def update_attr(attrs, attr_types, name, value, typ=None):
if name not in attrs:
attrs[name] = value
if typ:
attr_types[name] = typ
def apply_pass(name):
attrs = dict(pass_attrs)
attr_types = {}
update_attr(attrs, attr_types, "nranks", 1, "size_t")
update_attr(attrs, attr_types, "use_cuda", False, "bool")
# TODO(zjl): how to skip fetch variables ?
update_attr(attrs, attr_types, "mem_opt_skip_vars",
get_data_vars(main_program), "list[str]")
_apply_pass(main_program, startup_program, name, attrs, attr_types)
use_cuda = pass_attrs.get("use_cuda", False)
build_strategy = build_strategy._copy()
if build_strategy.sync_batch_norm:
apply_pass("sync_batch_norm_pass")
build_strategy.sync_batch_norm = False
if build_strategy.fuse_relu_depthwise_conv and use_cuda:
apply_pass("fuse_relu_depthwise_conv_pass")
build_strategy.fuse_relu_depthwise_conv = False
if build_strategy.fuse_bn_act_ops and use_cuda:
apply_pass("fuse_bn_act_pass")
build_strategy.fuse_bn_act_ops = False
if build_strategy.fuse_bn_add_act_ops and use_cuda:
apply_pass("fuse_bn_add_act_pass")
build_strategy.fuse_bn_add_act_ops = False
if build_strategy.enable_auto_fusion and use_cuda:
apply_pass("fusion_group_pass")
build_strategy.enable_auto_fusion = False
if build_strategy.fuse_elewise_add_act_ops:
apply_pass("fuse_elewise_add_act_pass")
build_strategy.fuse_elewise_add_act_ops = False
if build_strategy.fuse_all_optimizer_ops:
apply_pass("fuse_adam_op_pass")
apply_pass("fuse_sgd_op_pass")
apply_pass("fuse_momentum_op_pass")
build_strategy.fuse_all_optimizer_ops = False
# TODO(zjl): support fuse all reduce ops
if build_strategy.cache_runtime_context:
apply_pass("runtime_context_cache_pass")
build_strategy.cache_runtime_context = False
if build_strategy.enable_addto and use_cuda:
# NOTE: how to get fetch vars to skip memory optimization?
apply_pass("inplace_addto_op_pass")
build_strategy.enable_addto = False
if build_strategy.enable_inplace:
apply_pass("buffer_shared_inplace_pass")
build_strategy.enable_inplace = False
build_strategy._clear_finalized()
return build_strategy
......@@ -16,20 +16,16 @@ import paddle
from paddle.vision.models import resnet50
from paddle.nn import CrossEntropyLoss
from paddle.fluid.framework import _apply_pass
from paddle.fluid.ir import apply_build_strategy
import paddle.fluid as fluid
import unittest
import numpy as np
class TestApplyPassToProgram(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def global_block_contains_op(self, program, op_type):
for op in program.global_block().ops:
if op.type == op_type:
return True
return False
def test_case(self):
def get_resnet50_model():
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
image = paddle.static.data(
name="image", shape=[None, 3, 224, 224], dtype="float32")
label = paddle.static.data(name="label", shape=[None, 1], dtype="int64")
......@@ -37,14 +33,27 @@ class TestApplyPassToProgram(unittest.TestCase):
loss_fn = CrossEntropyLoss()
pred = model(image)
loss = loss_fn(pred, label)
optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
optimizer.minimize(loss)
startup = paddle.static.default_startup_program()
main = paddle.static.default_main_program()
return main, startup, image, label, loss
def global_block_contains_op(program, op_type):
for op in program.global_block().ops:
if op.type == op_type:
return True
return False
class TestApplyPassToProgram(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def test_case(self):
main, startup, image, label, loss = get_resnet50_model()
fused_op = "fused_elemwise_add_activation"
self.assertFalse(self.global_block_contains_op(main, fused_op))
self.assertFalse(global_block_contains_op(main, fused_op))
attrs = {
"int_attr": -3,
"size_t_attr": 10,
......@@ -59,7 +68,135 @@ class TestApplyPassToProgram(unittest.TestCase):
ret_attrs = _apply_pass(main, startup, "fuse_elewise_add_act_pass",
attrs, attr_types)
self.assertEqual(attrs, ret_attrs)
self.assertTrue(self.global_block_contains_op(main, fused_op))
self.assertTrue(global_block_contains_op(main, fused_op))
class TestIRPassBase(unittest.TestCase):
def setUp(self):
paddle.enable_static()
if paddle.is_compiled_with_cuda():
fluid.set_flags({
'FLAGS_cudnn_deterministic': 1,
'FLAGS_max_inplace_grad_add': 6,
})
self.place = paddle.CUDAPlace(0)
else:
self.place = paddle.CPUPlace()
self.use_cuda = isinstance(self.place, paddle.CUDAPlace)
self.executor = paddle.static.Executor(self.place)
self.num_classes = 1000
self.seed = 1
def get_strategy(self):
return {
'enable_inplace': True,
'enable_addto': True,
'fuse_all_optimizer_ops': True,
'fuse_elewise_add_act_ops': True,
'fuse_relu_depthwise_conv': True,
'fuse_bn_act_ops': True,
}
def check_before_applied(self, main, startup):
self.assertFalse(global_block_contains_op(main, "share_buffer"))
self.assertFalse(global_block_contains_op(main, "coalesce_tensor"))
self.assertFalse(
global_block_contains_op(main, "fused_elemwise_add_activation"))
adam_cnt = 0
for op in main.global_block().ops:
if op.type == "adam":
adam_cnt += 1
self.assertGreater(adam_cnt, 1)
def check_after_applied(self, main, startup):
self.assertTrue(global_block_contains_op(main, "share_buffer"))
# fused all optimizer pass requires this
if paddle.is_compiled_with_cuda():
self.assertTrue(global_block_contains_op(main, "coalesce_tensor"))
self.assertTrue(
global_block_contains_op(main, "fused_elemwise_add_activation"))
share_dims_cnt = 0
non_share_dims_cnt = 0
for op in main.global_block().ops:
if op.type != "share_buffer":
continue
share_dims = op.attr("share_dims")
if share_dims:
for i in range(len(share_dims)):
self.assertEqual(share_dims[0], share_dims[i])
if share_dims[0] is True:
share_dims_cnt += 1
else:
non_share_dims_cnt += 1
else:
non_share_dims_cnt += 1
if self.use_cuda:
self.assertGreaterEqual(share_dims_cnt, 1)
else:
self.assertEqual(share_dims_cnt, 0)
self.assertGreaterEqual(non_share_dims_cnt, 1)
if paddle.is_compiled_with_cuda():
adam_cnt = 0
for op in main.global_block().ops:
if op.type == "adam":
adam_cnt += 1
self.assertEqual(adam_cnt, 1)
def test_main(self):
if self.use_cuda:
batch_num = 20
batch_size = 4
else:
batch_num = 3
batch_size = 2
paddle.seed(self.seed)
main1, startup1, image, label, loss1 = get_resnet50_model()
main2, startup2, image, label, loss2 = get_resnet50_model()
build_strategy = paddle.static.BuildStrategy()
for k, v in self.get_strategy().items():
setattr(build_strategy, k, v)
self.check_before_applied(main2, startup2)
apply_build_strategy(main2, startup2, build_strategy,
{"use_cuda": self.use_cuda})
self.check_after_applied(main2, startup2)
image_shape = [batch_size] + list(image.shape)[1:]
label_shape = [batch_size] + list(label.shape)[1:]
paddle.seed(self.seed)
scope1 = paddle.static.Scope()
with paddle.static.scope_guard(scope1):
self.executor.run(startup1)
paddle.seed(self.seed)
scope2 = paddle.static.Scope()
with paddle.static.scope_guard(scope2):
self.executor.run(startup2)
for idx in range(batch_num):
feed = {
image.name: np.random.rand(*image_shape).astype('float32'),
label.name: np.random.randint(
low=0,
high=self.num_classes,
size=label_shape,
dtype='int64'),
}
with paddle.static.scope_guard(scope1):
loss_value1 = self.executor.run(main1,
feed=feed,
fetch_list=[loss1])[0]
with paddle.static.scope_guard(scope2):
loss_value2 = self.executor.run(main2,
feed=feed,
fetch_list=[loss2])[0]
self.assertEqual(loss_value1, loss_value2, "batch {}".format(idx))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册