未验证 提交 2fd8deea 编写于 作者: W wuhuanzhou 提交者: GitHub

C++ support register pass via PassDesc (#36095)

支持C++开发注册GeneratePass,简化针对fusion等子图优化场景开发方式。
上级 d8887afa
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
......@@ -224,6 +225,115 @@ bool GeneratePass::VerifyGraph(const Graph& graph) {
return true;
}
namespace generate_pass {
VarHelper::VarHelper(const char* name) : name_(name), type_(Type::kInput) {}
VarHelper::VarHelper(const std::string& name, Type type)
: name_(name), type_(type) {}
OpHelper::OpHelper(const char* type, SubgraphHelper* subgraph_helper)
: type_(type), subgraph_helper_(subgraph_helper) {
op_desc_ = subgraph_helper_->ProgramDesc()->mutable_blocks(0)->add_ops();
op_desc_->set_type(type_);
}
OpHelper::Arguments::Arguments(const char* parameter,
const VarHelper& var_helper)
: parameter_(parameter) {
var_helpers_.push_back(var_helper);
}
OpHelper::Arguments::Arguments(const char* parameter,
std::initializer_list<VarHelper> var_helpers)
: parameter_(parameter), var_helpers_(var_helpers) {}
OpHelper& OpHelper::operator()(const Arguments& input) {
proto::OpDesc::Var* var = op_desc_->add_inputs();
var->set_parameter(input.parameter_);
for (const VarHelper& var_helper : input.var_helpers_) {
var->add_arguments()->assign(var_helper.name_);
if (VarHelper::Type::kInput == var_helper.type_) {
subgraph_helper_->AddInputVar(var_helper.name_);
}
}
return *this;
}
OpHelper& OpHelper::operator()(std::initializer_list<Arguments> inputs) {
for (const auto& input : inputs) {
operator()(input);
}
return *this;
}
VarHelper OpHelper::Out(const char* name) {
std::string argument = patterns::UniqueKey(type_);
proto::OpDesc::Var* var = op_desc_->add_outputs();
var->set_parameter(name);
var->add_arguments()->assign(argument);
return VarHelper(argument, VarHelper::Type::kOutput);
}
proto::ProgramDesc* SubgraphHelper::ProgramDesc() { return &program_desc_; }
const proto::ProgramDesc& SubgraphHelper::ProgramDesc() const {
return program_desc_;
}
const std::vector<std::string>& SubgraphHelper::InputVars() const {
return input_vars_;
}
const std::vector<std::string>& SubgraphHelper::OutputVars() const {
return output_vars_;
}
void SubgraphHelper::AddInputVar(const std::string& name) {
auto iter = std::find(input_vars_.begin(), input_vars_.end(), name);
if (input_vars_.end() == iter) {
input_vars_.push_back(name);
}
}
void SubgraphHelper::AddOutputVars(const VarHelper& var_helper) {
output_vars_.push_back(var_helper.name_);
}
} // namespace generate_pass
PassPairs::PassPairs(const SubgraphType& pattern, const SubgraphType& replace) {
AddPassDesc(pattern, replace);
}
void PassPairs::AddPassDesc(const SubgraphType& pattern,
const SubgraphType& replace) {
proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs();
pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc());
pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc());
PADDLE_ENFORCE_EQ(pattern.InputVars().size(), replace.InputVars().size(),
platform::errors::InvalidArgument(
"Size of lambda expression arguments is not equal "
"between pattern/replace subgraph."));
for (size_t i = 0; i < pattern.InputVars().size(); i++) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(pattern.InputVars()[i]);
var_map->set_replace_var(replace.InputVars()[i]);
}
PADDLE_ENFORCE_EQ(pattern.OutputVars().size(), replace.OutputVars().size(),
platform::errors::InvalidArgument(
"Size of lambda expression returns is not equal "
"between pattern/replace subgraph."));
for (size_t i = 0; i < pattern.OutputVars().size(); i++) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(pattern.OutputVars()[i]);
var_map->set_replace_var(replace.OutputVars()[i]);
}
}
const proto::MultiPassDesc& PassPairs::MultiPassDesc() const {
return multi_pass_desc_;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/pass_desc.pb.h"
......@@ -43,6 +42,158 @@ class GeneratePass : public Pass {
proto::MultiPassDesc multi_pass_desc_;
};
namespace generate_pass {
class VarHelper;
class OpHelper;
class SubgraphHelper;
// VarHelper is used to represent a variable node.
struct VarHelper {
enum class Type { kInput, kOutput };
explicit VarHelper(const char* name);
VarHelper(const std::string& name, Type type);
std::string name_;
Type type_;
};
// OpHelper is used to represent a operator node.
class OpHelper {
public:
// Convert multiple inputs.
struct Arguments {
Arguments(const char* parameter, const VarHelper& var_helper);
Arguments(const char* parameter,
std::initializer_list<VarHelper> var_helpers);
std::string parameter_;
std::vector<VarHelper> var_helpers_;
};
OpHelper(const char* type, SubgraphHelper* subgraph_helper);
OpHelper& operator()(const Arguments& input);
OpHelper& operator()(std::initializer_list<Arguments> inputs);
VarHelper Out(const char* name);
private:
OpHelper() = delete;
DISABLE_COPY_AND_ASSIGN(OpHelper);
const char* type_;
proto::OpDesc* op_desc_;
SubgraphHelper* subgraph_helper_;
};
/*
* SubgraphHelper is used to define pattern/replace subgraphs.
*
* Use lambda expression to define subgraph like Python. SubgraphHelper
* converts lambda expression to ProgramDesc.
*
* In order to define a subgraph, user need to use VarHelper and OpHelper.
* Use the macros instead of class names, so user can develop better and
* don't need to know too much about underlying implementation.
*
* An example of defining a subgraph as follows:
*
* SUBGRAPH_(subgraph)([subgraph=&subgraph](VAR_(x), VAR_(y), VAR_(z)) {
* auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
* auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
* return ewadd2;
* });
*
*/
class SubgraphHelper {
public:
SubgraphHelper() = default;
// The lambda expression is a prvalue expression.
template <typename T>
SubgraphHelper& operator=(const T&& f) {
proto::BlockDesc* block = program_desc_.add_blocks();
block->set_idx(0);
block->set_parent_idx(0);
AddOutputVars(f());
return *this;
}
proto::ProgramDesc* ProgramDesc();
const proto::ProgramDesc& ProgramDesc() const;
const std::vector<std::string>& InputVars() const;
const std::vector<std::string>& OutputVars() const;
void AddInputVar(const std::string& name);
void AddOutputVars(const VarHelper& var_helper);
template <size_t i, typename... Ts,
std::enable_if_t<i + 1 < sizeof...(Ts)>* = nullptr>
void AddOutputVars(const std::tuple<Ts...>& outputs) {
AddOutputVars(std::get<i>(outputs));
AddOutputVars<i + 1>(outputs);
}
template <size_t i, typename... Ts,
std::enable_if_t<i + 1 == sizeof...(Ts)>* = nullptr>
void AddOutputVars(const std::tuple<Ts...>& outputs) {
AddOutputVars(std::get<i>(outputs));
}
template <typename... Ts>
void AddOutputVars(const std::tuple<Ts...>& outputs) {
AddOutputVars<0>(outputs);
}
private:
DISABLE_COPY_AND_ASSIGN(SubgraphHelper);
std::vector<std::string> input_vars_;
std::vector<std::string> output_vars_;
proto::ProgramDesc program_desc_;
};
} // namespace generate_pass
class PassPairs {
public:
using SubgraphType = generate_pass::SubgraphHelper;
PassPairs() = default;
PassPairs(const SubgraphType& pattern, const SubgraphType& replace);
void AddPassDesc(const SubgraphType& pattern, const SubgraphType& replace);
const proto::MultiPassDesc& MultiPassDesc() const;
private:
proto::MultiPassDesc multi_pass_desc_;
};
// Use function to register in CC.
template <PassPairs (*Functor)(void)>
class MacroPassHelper : public GeneratePass {
public:
MacroPassHelper() : GeneratePass(Functor().MultiPassDesc()) {}
};
#define VAR_(name) \
::paddle::framework::ir::generate_pass::VarHelper name = \
::paddle::framework::ir::generate_pass::VarHelper(#name)
#define OP_(type) \
::paddle::framework::ir::generate_pass::OpHelper(#type, subgraph)
#define SUBGRAPH_(name) \
::paddle::framework::ir::generate_pass::SubgraphHelper name; \
name
#define REGISTER_GENERATE_PASS(pass_type) \
paddle::framework::ir::PassPairs register_##pass_type(); \
REGISTER_PASS( \
pass_type, \
::paddle::framework::ir::MacroPassHelper<&register_##pass_type>); \
paddle::framework::ir::PassPairs register_##pass_type()
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -16,234 +16,71 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
template <proto::MultiPassDesc (*Functor)(void)>
class CXXGeneratePass : public GeneratePass {
public:
CXXGeneratePass() : GeneratePass(Functor()) {}
};
#define REGISTER_GENERATE_PASS(pass_type, function) \
REGISTER_PASS(pass_type, ::paddle::framework::ir::CXXGeneratePass<&function>)
proto::MultiPassDesc generate_fc_fuse() {
proto::MultiPassDesc multi_pass_desc;
REGISTER_GENERATE_PASS(generate_fc_fuse) {
paddle::framework::ir::PassPairs pass_pairs;
for (bool with_relu : {true, false}) {
proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs();
proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks();
pattern->set_idx(0);
pattern->set_parent_idx(0);
proto::OpDesc* mul = pattern->add_ops();
mul->set_type("mul");
proto::OpDesc::Var* mul_x = mul->add_inputs();
mul_x->set_parameter("X");
mul_x->add_arguments()->assign("x");
proto::OpDesc::Var* mul_y = mul->add_inputs();
mul_y->set_parameter("Y");
mul_y->add_arguments()->assign("w");
proto::OpDesc::Var* mul_out = mul->add_outputs();
mul_out->set_parameter("Out");
mul_out->add_arguments()->assign("mul_out");
proto::OpDesc* ewadd = pattern->add_ops();
ewadd->set_type("elementwise_add");
proto::OpDesc::Var* ewadd_x = ewadd->add_inputs();
ewadd_x->set_parameter("X");
ewadd_x->add_arguments()->assign("mul_out");
proto::OpDesc::Var* ewadd_y = ewadd->add_inputs();
ewadd_y->set_parameter("Y");
ewadd_y->add_arguments()->assign("b");
proto::OpDesc::Var* ewadd_out = ewadd->add_outputs();
ewadd_out->set_parameter("Out");
ewadd_out->add_arguments()->assign("ewadd_out");
proto::OpDesc* relu = nullptr;
proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks();
replace->set_idx(0);
replace->set_parent_idx(0);
proto::OpDesc* fc = replace->add_ops();
fc->set_type("fc");
proto::OpDesc::Var* fc_x = fc->add_inputs();
fc_x->set_parameter("Input");
fc_x->add_arguments()->assign("x");
proto::OpDesc::Var* fc_w = fc->add_inputs();
fc_w->set_parameter("W");
fc_w->add_arguments()->assign("w");
proto::OpDesc::Var* fc_b = fc->add_inputs();
fc_b->set_parameter("Bias");
fc_b->add_arguments()->assign("b");
proto::OpDesc::Var* fc_out = fc->add_outputs();
fc_out->set_parameter("Out");
fc_out->add_arguments()->assign("fc_out");
for (const char* var : {"x", "w", "b", "fc_out"}) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(var);
var_map->set_replace_var(var);
}
proto::PassDesc::AttrMap* attr_map = pass_desc->add_attr_maps();
attr_map->set_pattern_op_idx(0);
attr_map->set_pattern_name("x_num_col_dims");
attr_map->set_replace_op_idx(0);
attr_map->set_replace_name("in_num_col_dims");
if (with_relu) {
relu = pattern->add_ops();
relu->set_type("relu");
proto::OpDesc::Var* relu_x = relu->add_inputs();
relu_x->set_parameter("X");
relu_x->add_arguments()->assign("ewadd_out");
proto::OpDesc::Var* relu_out = relu->add_outputs();
relu_out->set_parameter("Out");
relu_out->add_arguments()->assign("relu_out");
pass_desc->mutable_var_maps(3)->set_pattern_var("relu_out");
proto::OpDesc::Attr* attr = fc->add_attrs();
attr->set_name("activation_type");
attr->set_type(proto::AttrType::STRING);
attr->set_s("relu");
} else {
pass_desc->mutable_var_maps(3)->set_pattern_var("ewadd_out");
}
// pattern
SUBGRAPH_(pattern) =
[ subgraph = &pattern, with_relu ](VAR_(x), VAR_(y), VAR_(z)) {
VLOG(3) << "exec lambda func.";
auto mul = OP_(mul)({{"X", x}, {"Y", y}}).Out("Out");
auto ewadd = OP_(elementwise_add)({{"X", mul}, {"Y", z}}).Out("Out");
if (with_relu) {
return OP_(relu)({"X", ewadd}).Out("Out");
} else {
return ewadd;
}
};
// replace
SUBGRAPH_(replace) =
[ subgraph = &replace, with_relu ](VAR_(x), VAR_(y), VAR_(z)) {
auto& fc = OP_(fc)({{"Input", x}, {"W", y}, {"Bias", z}});
return fc.Out("Out");
};
pass_pairs.AddPassDesc(pattern, replace);
}
return multi_pass_desc;
return pass_pairs;
}
proto::MultiPassDesc generate_multi_add_to_addn() {
proto::MultiPassDesc multi_pass_desc;
proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs();
proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks();
proto::OpDesc* ewadd_0 = pattern->add_ops();
ewadd_0->set_type("elementwise_add");
proto::OpDesc::Var* ewadd_0_x = ewadd_0->add_inputs();
ewadd_0_x->set_parameter("X");
ewadd_0_x->add_arguments()->assign("a");
proto::OpDesc::Var* ewadd_0_y = ewadd_0->add_inputs();
ewadd_0_y->set_parameter("Y");
ewadd_0_y->add_arguments()->assign("b");
proto::OpDesc::Var* ewadd_0_out = ewadd_0->add_outputs();
ewadd_0_out->set_parameter("Out");
ewadd_0_out->add_arguments()->assign("ewadd_out_0");
proto::OpDesc* ewadd_1 = pattern->add_ops();
ewadd_1->set_type("elementwise_add");
proto::OpDesc::Var* ewadd_1_x = ewadd_1->add_inputs();
ewadd_1_x->set_parameter("X");
ewadd_1_x->add_arguments()->assign("ewadd_out_0");
proto::OpDesc::Var* ewadd_1_y = ewadd_1->add_inputs();
ewadd_1_y->set_parameter("Y");
ewadd_1_y->add_arguments()->assign("c");
proto::OpDesc::Var* ewadd_1_out = ewadd_1->add_outputs();
ewadd_1_out->set_parameter("Out");
ewadd_1_out->add_arguments()->assign("ewadd_out_1");
proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks();
proto::OpDesc* addn = replace->add_ops();
addn->set_type("add_n");
proto::OpDesc::Var* addn_x = addn->add_inputs();
addn_x->set_parameter("X");
addn_x->add_arguments()->assign("a");
addn_x->add_arguments()->assign("b");
addn_x->add_arguments()->assign("c");
proto::OpDesc::Var* addn_out = addn->add_outputs();
addn_out->set_parameter("Out");
addn_out->add_arguments()->assign("addn_out");
for (const char* var : {"a", "b", "c", "ewadd_out_1"}) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(var);
var_map->set_replace_var(var);
}
pass_desc->mutable_var_maps(3)->set_replace_var("addn_out");
return multi_pass_desc;
REGISTER_GENERATE_PASS(generate_multi_add_to_addn) {
// pattern
SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
return ewadd2;
};
// replace
SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
return OP_(sum)({"X", {x, y, z}}).Out("Out");
};
return {pattern, replace};
}
proto::MultiPassDesc generate_combine_matmul() {
proto::MultiPassDesc multi_pass_desc;
proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs();
proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks();
proto::OpDesc* matmul_0 = pattern->add_ops();
matmul_0->set_type("matmul");
proto::OpDesc::Var* matmul_0_x = matmul_0->add_inputs();
matmul_0_x->set_parameter("X");
matmul_0_x->add_arguments()->assign("a");
proto::OpDesc::Var* matmul_0_y = matmul_0->add_inputs();
matmul_0_y->set_parameter("Y");
matmul_0_y->add_arguments()->assign("b");
proto::OpDesc::Var* matmul_0_out = matmul_0->add_outputs();
matmul_0_out->set_parameter("Out");
matmul_0_out->add_arguments()->assign("matmul_out_0");
proto::OpDesc* matmul_1 = pattern->add_ops();
matmul_1->set_type("matmul");
proto::OpDesc::Var* matmul_1_x = matmul_1->add_inputs();
matmul_1_x->set_parameter("X");
matmul_1_x->add_arguments()->assign("a");
proto::OpDesc::Var* matmul_1_y = matmul_1->add_inputs();
matmul_1_y->set_parameter("Y");
matmul_1_y->add_arguments()->assign("c");
proto::OpDesc::Var* matmul_1_out = matmul_1->add_outputs();
matmul_1_out->set_parameter("Out");
matmul_1_out->add_arguments()->assign("matmul_out_1");
proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks();
proto::OpDesc* concat = replace->add_ops();
concat->set_type("concat");
proto::OpDesc::Var* concat_x = concat->add_inputs();
concat_x->set_parameter("X");
concat_x->add_arguments()->assign("b");
concat_x->add_arguments()->assign("c");
proto::OpDesc::Var* concat_out = concat->add_outputs();
concat_out->set_parameter("Out");
concat_out->add_arguments()->assign("concat_out");
proto::OpDesc* matmul = replace->add_ops();
matmul->set_type("matmul");
proto::OpDesc::Var* matmul_x = matmul->add_inputs();
matmul_x->set_parameter("X");
matmul_x->add_arguments()->assign("a");
proto::OpDesc::Var* matmul_y = matmul->add_inputs();
matmul_y->set_parameter("Y");
matmul_y->add_arguments()->assign("concat_out");
proto::OpDesc::Var* matmul_out = matmul->add_outputs();
matmul_out->set_parameter("Out");
matmul_out->add_arguments()->assign("matmul_out");
proto::OpDesc* slice_0 = replace->add_ops();
slice_0->set_type("slice");
proto::OpDesc::Var* slice_0_x = slice_0->add_inputs();
slice_0_x->set_parameter("X");
slice_0_x->add_arguments()->assign("matmul_out");
proto::OpDesc::Var* slice_0_out = slice_0->add_outputs();
slice_0_out->set_parameter("Out");
slice_0_out->add_arguments()->assign("slice_out_0");
proto::OpDesc* slice_1 = replace->add_ops();
slice_1->set_type("slice");
proto::OpDesc::Var* slice_1_x = slice_1->add_inputs();
slice_1_x->set_parameter("X");
slice_1_x->add_arguments()->assign("matmul_out");
proto::OpDesc::Var* slice_1_out = slice_1->add_outputs();
slice_1_out->set_parameter("Out");
slice_1_out->add_arguments()->assign("slice_out_1");
for (const char* var : {"a", "b", "c", "matmul_out_0", "matmul_out_1"}) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(var);
var_map->set_replace_var(var);
}
pass_desc->mutable_var_maps(3)->set_replace_var("slice_out_0");
pass_desc->mutable_var_maps(4)->set_replace_var("slice_out_1");
return multi_pass_desc;
REGISTER_GENERATE_PASS(generate_combine_matmul) {
// pattern
SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
auto matmul1 = OP_(matmul)({{"X", x}, {"Y", y}}).Out("Out");
auto matmul2 = OP_(matmul)({{"X", x}, {"Y", z}}).Out("Out");
return std::make_tuple(matmul1, matmul2);
};
// replace
SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
auto concat = OP_(concat)({"X", {y, z}}).Out("Out");
auto matmul = OP_(matmul)({{"X", x}, {"Y", concat}}).Out("Out");
auto slice1 = OP_(slice)({"X", matmul}).Out("Out");
auto slice2 = OP_(slice)({"X", matmul}).Out("Out");
return std::make_tuple(slice1, slice2);
};
return {pattern, replace};
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_GENERATE_PASS(generate_fc_fuse,
paddle::framework::ir::generate_fc_fuse);
REGISTER_GENERATE_PASS(generate_multi_add_to_addn,
paddle::framework::ir::generate_multi_add_to_addn);
REGISTER_GENERATE_PASS(generate_combine_matmul,
paddle::framework::ir::generate_combine_matmul);
namespace paddle {
namespace framework {
namespace ir {
TEST(GeneratePass, construct_with_string) {
std::string binary_str;
generate_fc_fuse().SerializeToString(&binary_str);
register_generate_fc_fuse().MultiPassDesc().SerializeToString(&binary_str);
GeneratePass generate_pass(binary_str);
}
......@@ -318,7 +155,7 @@ TEST(GeneratePass, generate_multi_add_to_addn) {
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_addn_nodes_after = GetNumOpNodes(graph, "add_n");
int num_addn_nodes_after = GetNumOpNodes(graph, "sum");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册