Unverified commit 463a4f25, authored by M ming1753, committed by GitHub

[IR&PASS] add conv + elementwise_add fuse pattern (#55176)

* [IR&PASS] add conv + elementwise_add fuse pattern

* add conv2dAddPattern to pass
Parent 88b986a6
...
@@ -49,6 +49,13 @@
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/phi/core/ddim.h"
// Headers needed to build Conv2dFusionOp.
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/infermeta/multiary.h"
// Define op1.
class Operation1 : public ir::Op<Operation1> {
 public:
...
@@ -329,6 +336,623 @@ class Conv2dBnFusePattern
  }
};
namespace paddle {
namespace dialect {
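// A hand-written fused conv2d op used only by this test: it takes input,
// filter, bias and an optional residual, and produces one output tensor plus
// a vector of extra outputs (one per entry in the `channels` attribute).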
class Conv2dFusionOpTest : public ir::Op<Conv2dFusionOpTest,
OpYamlInfoInterface,
InferMetaInterface> {
public:
using Op::Op;
static const char *name() { return "pd.conv2d_fusion_test"; }
static const char *attributes_name[10];
static constexpr uint32_t attributes_num = 10;
static OpInfoTuple GetOpInfo();
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult input_,
ir::OpResult filter_,
ir::OpResult bias_,
ir::OpResult residual_,
const std::vector<int> &strides,
const std::vector<int> &paddings_t,
std::string padding_algorithm,
const std::vector<int> &dilations_t,
int groups,
std::string data_format,
std::string activation,
bool exhaustive_search,
const std::vector<int> &channels,
int user_workspace_size);
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult input_,
ir::OpResult filter_,
ir::OpResult bias_,
ir::OpResult residual_,
ir::AttributeMap attributes);
void Verify();
ir::Value input() { return operand(0); }
ir::Value filter() { return operand(1); }
ir::Value bias() { return operand(2); }
ir::Value residual() { return operand(3); }
ir::OpResult output() { return result(0); }
ir::OpResult outputs() { return result(1); }
  ir::Attribute attribute(const std::string &name) {
    PADDLE_ENFORCE(
        attributes().count(name) > 0,
        phi::errors::PreconditionNotMet("Attribute does not exist."));
    return attributes().at(name);
  }
  template <typename T>
  T attribute(const std::string &name) {
    PADDLE_ENFORCE(
        attributes().count(name) > 0 && attributes().at(name).isa<T>(),
        phi::errors::PreconditionNotMet("Attribute type is not right."));
    return attributes().at(name).dyn_cast<T>();
  }
static void InferMeta(phi::InferMetaContext *infer_meta);
};
const char *Conv2dFusionOpTest::attributes_name[10] = {"strides",
"paddings_t",
"padding_algorithm",
"dilations_t",
"groups",
"data_format",
"activation",
"exhaustive_search",
"channels",
"user_workspace_size"};
OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
OpInputInfo(
"input", "paddle::dialect::DenseTensorType", false, false, false),
OpInputInfo(
"filter", "paddle::dialect::DenseTensorType", false, false, false),
OpInputInfo(
"bias", "paddle::dialect::DenseTensorType", false, false, false),
OpInputInfo(
"residual", "paddle::dialect::DenseTensorType", true, false, false)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {
OpAttributeInfo("strides", "ir::ArrayAttribute<ir::Int32Attribute>", ""),
OpAttributeInfo(
"paddings_t", "ir::ArrayAttribute<ir::Int32Attribute>", ""),
OpAttributeInfo("padding_algorithm", "ir::StrAttribute", ""),
OpAttributeInfo(
"dilations_t", "ir::ArrayAttribute<ir::Int32Attribute>", ""),
OpAttributeInfo("groups", "ir::Int32Attribute", ""),
OpAttributeInfo("data_format", "ir::StrAttribute", ""),
OpAttributeInfo("activation", "ir::StrAttribute", ""),
OpAttributeInfo("exhaustive_search", "ir::BoolAttribute", ""),
OpAttributeInfo("channels", "ir::ArrayAttribute<ir::Int32Attribute>", ""),
OpAttributeInfo("user_workspace_size", "ir::Int32Attribute", "")};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
OpOutputInfo("output", "paddle::dialect::DenseTensorType", false, false),
OpOutputInfo("outputs",
"ir::VectorType<paddle::dialect::DenseTensorType>",
false,
false)};
paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("Conv2dFusionInferMeta",
{"input",
"filter",
"bias",
"residual",
"strides",
"paddings_t",
"padding_algorithm",
"dilations_t",
"groups",
"data_format",
"activation",
"exhaustive_search",
"channels",
"user_workspace_size"},
{"ConvFusionKernel"},
{"input",
"filter",
"bias",
"residual",
"strides",
"paddings_t",
"padding_algorithm",
"dilations_t",
"groups",
"data_format",
"activation",
"exhaustive_search",
"channels",
"user_workspace_size"},
{"input"},
{},
{});
return std::make_tuple(inputs, attributes, outputs, run_time_info);
}
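// This Build overload unpacks the typed values from `attributes`, registers
// the operands and attributes on `argument`, and then runs infer-meta on CPU
// stand-in tensors to derive the result types.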
void Conv2dFusionOpTest::Build(ir::Builder &builder,
ir::OperationArgument &argument,
ir::OpResult input_,
ir::OpResult filter_,
ir::OpResult bias_,
ir::OpResult residual_,
ir::AttributeMap attributes) {
std::vector<int> strides;
for (size_t i = 0;
i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
i++) {
strides.push_back(attributes.at("strides")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::Int32Attribute>()
.data());
}
std::vector<int> paddings_t;
for (size_t i = 0;
i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
paddings_t.push_back(attributes.at("paddings_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::Int32Attribute>()
.data());
}
std::string padding_algorithm =
attributes.at("padding_algorithm").dyn_cast<ir::StrAttribute>().data();
std::vector<int> dilations_t;
for (size_t i = 0;
i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
dilations_t.push_back(attributes.at("dilations_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::Int32Attribute>()
.data());
}
int groups = attributes.at("groups").dyn_cast<ir::Int32Attribute>().data();
std::string data_format =
attributes.at("data_format").dyn_cast<ir::StrAttribute>().data();
std::string activation =
attributes.at("activation").dyn_cast<ir::StrAttribute>().data();
bool exhaustive_search =
attributes.at("exhaustive_search").dyn_cast<ir::BoolAttribute>().data();
std::vector<int> channels;
for (size_t i = 0;
i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size();
i++) {
channels.push_back(attributes.at("channels")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::Int32Attribute>()
.data());
}
int user_workspace_size = attributes.at("user_workspace_size")
.dyn_cast<ir::Int32Attribute>()
.data();
VLOG(4) << "Builder construction inputs";
std::vector<ir::OpResult> argument_inputs = {
input_, filter_, bias_, residual_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
VLOG(4) << "Builder construction attributes";
std::vector<ir::Attribute> vec_strides;
for (size_t i = 0; i < static_cast<size_t>(strides.size()); i++) {
ir::Attribute attr_strides =
ir::Int32Attribute::get(ir::IrContext::Instance(), strides[i]);
vec_strides.push_back(attr_strides);
}
ir::Attribute attr_strides =
ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_strides);
argument.AddAttribute("strides", attr_strides);
std::vector<ir::Attribute> vec_paddings_t;
for (size_t i = 0; i < static_cast<size_t>(paddings_t.size()); i++) {
ir::Attribute attr_paddings_t =
ir::Int32Attribute::get(ir::IrContext::Instance(), paddings_t[i]);
vec_paddings_t.push_back(attr_paddings_t);
}
ir::Attribute attr_paddings_t =
ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_paddings_t);
argument.AddAttribute("paddings_t", attr_paddings_t);
ir::Attribute attr_padding_algorithm =
ir::StrAttribute::get(ir::IrContext::Instance(), padding_algorithm);
argument.AddAttribute("padding_algorithm", attr_padding_algorithm);
std::vector<ir::Attribute> vec_dilations_t;
for (size_t i = 0; i < static_cast<size_t>(dilations_t.size()); i++) {
ir::Attribute attr_dilations_t =
ir::Int32Attribute::get(ir::IrContext::Instance(), dilations_t[i]);
vec_dilations_t.push_back(attr_dilations_t);
}
ir::Attribute attr_dilations_t =
ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_dilations_t);
argument.AddAttribute("dilations_t", attr_dilations_t);
ir::Attribute attr_groups =
ir::Int32Attribute::get(ir::IrContext::Instance(), groups);
argument.AddAttribute("groups", attr_groups);
ir::Attribute attr_data_format =
ir::StrAttribute::get(ir::IrContext::Instance(), data_format);
argument.AddAttribute("data_format", attr_data_format);
ir::Attribute attr_activation =
ir::StrAttribute::get(ir::IrContext::Instance(), activation);
argument.AddAttribute("activation", attr_activation);
ir::Attribute attr_exhaustive_search =
ir::BoolAttribute::get(ir::IrContext::Instance(), exhaustive_search);
argument.AddAttribute("exhaustive_search", attr_exhaustive_search);
std::vector<ir::Attribute> vec_channels;
for (size_t i = 0; i < static_cast<size_t>(channels.size()); i++) {
ir::Attribute attr_channels =
ir::Int32Attribute::get(ir::IrContext::Instance(), channels[i]);
vec_channels.push_back(attr_channels);
}
ir::Attribute attr_channels =
ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_channels);
argument.AddAttribute("channels", attr_channels);
ir::Attribute attr_user_workspace_size =
ir::Int32Attribute::get(ir::IrContext::Instance(), user_workspace_size);
argument.AddAttribute("user_workspace_size", attr_user_workspace_size);
VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorType input =
input_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)input;
paddle::dialect::DenseTensorType filter =
filter_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)filter;
paddle::dialect::DenseTensorType bias =
bias_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)bias;
// paddle::dialect::DenseTensorType residual =
// residual_.type().dyn_cast<paddle::dialect::DenseTensorType>();
// (void)residual;
VLOG(4) << "Builder construction dense_input";
phi::DenseTensor dense_input(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(TransToPhiDataType(input.dtype()),
input.dims(),
input.data_layout(),
input.lod(),
input.offset()));
VLOG(4) << "Builder construction meta_input";
phi::MetaTensor meta_input(&dense_input);
VLOG(4) << "Builder construction dense_filter";
phi::DenseTensor dense_filter(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(TransToPhiDataType(filter.dtype()),
filter.dims(),
filter.data_layout(),
filter.lod(),
filter.offset()));
VLOG(4) << "Builder construction meta_filter";
phi::MetaTensor meta_filter(&dense_filter);
VLOG(4) << "Builder construction dense_bias";
phi::DenseTensor dense_bias(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(TransToPhiDataType(bias.dtype()),
bias.dims(),
bias.data_layout(),
bias.lod(),
bias.offset()));
VLOG(4) << "Builder construction meta_bias";
phi::MetaTensor meta_bias(&dense_bias);
// VLOG(4) << "Builder construction dense_residual";
// phi::DenseTensor
// dense_residual(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
// phi::DenseTensorMeta(TransToPhiDataType(residual.dtype()),
// residual.dims(),
// residual.data_layout(),
// residual.lod(),
// residual.offset()));
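  // The residual input is optional here, so the dense-tensor setup above is
  // left disabled and a default-constructed MetaTensor is used instead.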
VLOG(4) << "Builder construction meta_residual";
// phi::MetaTensor meta_residual(&dense_residual);
phi::MetaTensor meta_residual;
phi::DenseTensor dense_output;
phi::MetaTensor meta_output(&dense_output);
std::vector<phi::DenseTensor> vec_dense_outputs((channels.size()),
phi::DenseTensor());
std::vector<phi::MetaTensor> vec_meta_outputs;
for (size_t i = 0; i < static_cast<size_t>(channels.size()); i++) {
vec_meta_outputs.push_back(phi::MetaTensor(&vec_dense_outputs[i]));
}
std::vector<phi::MetaTensor *> meta_outputs;
for (size_t i = 0; i < static_cast<size_t>(vec_meta_outputs.size()); i++) {
meta_outputs.push_back(&vec_meta_outputs[i]);
}
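  // Run the fused-conv infer-meta on the stand-in meta tensors; it fills in
  // meta_output, whose dtype/dims/layout become the IR result type below.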
phi::FusedConvInferMeta(meta_input,
meta_filter,
meta_bias,
meta_residual,
strides,
paddings_t,
padding_algorithm,
dilations_t,
groups,
data_format,
"float32",
"identity",
false,
false,
&meta_output,
phi::MetaConfig());
std::vector<ir::Type> argument_outputs;
auto output_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
TransToIrDataType(dense_output.dtype()),
dense_output.dims(),
dense_output.layout(),
dense_output.lod(),
dense_output.offset());
LOG(INFO) << output_dense_tensor_type;
argument_outputs.push_back(output_dense_tensor_type);
std::vector<ir::Type> outputs_types;
for (size_t i = 0; i < static_cast<size_t>(channels.size()); i++) {
outputs_types.push_back(paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
TransToIrDataType(vec_dense_outputs[i].dtype()),
vec_dense_outputs[i].dims(),
vec_dense_outputs[i].layout(),
vec_dense_outputs[i].lod(),
vec_dense_outputs[i].offset()));
}
ir::Type outputs_vector_type =
ir::VectorType::get(ir::IrContext::Instance(), outputs_types);
argument_outputs.push_back(outputs_vector_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
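// Verify checks that the op was built consistently: four inputs (the fourth
// one optional), all ten attributes with the expected IR attribute types, and
// two outputs.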
void Conv2dFusionOpTest::Verify() {
VLOG(4)
<< "Start Verifying inputs, outputs and attributes for: Conv2dFusionOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
4u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 4.", input_size));
PADDLE_ENFORCE(
(*this)->operand(0).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
PADDLE_ENFORCE(
(*this)->operand(1).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input."));
PADDLE_ENFORCE(
(*this)->operand(2).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 2th input."));
if (auto val = (*this)->op_operand(3)) {
PADDLE_ENFORCE(val.type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 3th input."));
}
}
VLOG(4) << "Verifying attributes:";
{
auto &attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("strides") > 0 &&
attributes.at("strides").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: strides is not right."));
for (size_t i = 0;
i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("strides")
.dyn_cast<ir::ArrayAttribute>()[i]
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: strides is not right."));
}
PADDLE_ENFORCE(attributes.count("paddings_t") > 0 &&
attributes.at("paddings_t").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: paddings_t is not right."));
for (size_t i = 0;
i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("paddings_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: paddings_t is not right."));
}
PADDLE_ENFORCE(
attributes.count("padding_algorithm") > 0 &&
attributes.at("padding_algorithm").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: padding_algorithm is not right."));
PADDLE_ENFORCE(attributes.count("dilations_t") > 0 &&
attributes.at("dilations_t").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: dilations_t is not right."));
for (size_t i = 0;
i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("dilations_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: dilations_t is not right."));
}
PADDLE_ENFORCE(attributes.count("groups") > 0 &&
attributes.at("groups").isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: groups is not right."));
PADDLE_ENFORCE(attributes.count("data_format") > 0 &&
attributes.at("data_format").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: data_format is not right."));
PADDLE_ENFORCE(attributes.count("activation") > 0 &&
attributes.at("activation").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: activation is not right."));
PADDLE_ENFORCE(
attributes.count("exhaustive_search") > 0 &&
attributes.at("exhaustive_search").isa<ir::BoolAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: exhaustive_search is not right."));
PADDLE_ENFORCE(attributes.count("channels") > 0 &&
attributes.at("channels").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: channels is not right."));
for (size_t i = 0;
i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("channels")
.dyn_cast<ir::ArrayAttribute>()[i]
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: channels is not right."));
}
PADDLE_ENFORCE(
attributes.count("user_workspace_size") > 0 &&
attributes.at("user_workspace_size").isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: user_workspace_size is not right."));
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
2u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 2.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
auto output_1_type = (*this)->result(1).type();
if (auto vec_type = output_1_type.dyn_cast<ir::VectorType>()) {
for (size_t i = 0; i < vec_type.size(); i++) {
PADDLE_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th output."));
}
} else {
PADDLE_ENFORCE(output_1_type.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th output."));
}
}
VLOG(4) << "End Verifying for: Conv2dFusionOp.";
}
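// Bridges the InferMetaInterface to the phi fused-conv infer-meta function.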
void Conv2dFusionOpTest::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::FusedConvInferMeta);
fn(infer_meta);
}
} // namespace dialect
} // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Conv2dFusionOpTest)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Conv2dFusionOpTest)
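// A minimal dialect whose only job is to register Conv2dFusionOpTest with the
// IrContext used in the test below.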
class Conv2dFusionTestDialect : public ir::Dialect {
public:
explicit Conv2dFusionTestDialect(ir::IrContext *context)
      : ir::Dialect(
            name(), context, ir::TypeId::get<Conv2dFusionTestDialect>()) {
initialize();
}
  static const char *name() { return "conv2d fusion test"; }
private:
void initialize() { RegisterOps<paddle::dialect::Conv2dFusionOpTest>(); }
};
IR_DECLARE_EXPLICIT_TYPE_ID(Conv2dFusionTestDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(Conv2dFusionTestDialect)
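// Rewrites conv2d followed by elementwise add into a single
// pd.conv2d_fusion_test op: the conv2d result must have exactly one use, the
// add's y operand becomes the bias, and the residual input is left empty.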
class Conv2dAddFusePattern
: public ir::OpRewritePattern<paddle::dialect::AddOp> {
public:
using ir::OpRewritePattern<paddle::dialect::AddOp>::OpRewritePattern;
bool MatchAndRewrite(
paddle::dialect::AddOp op,
ir::PatternRewriter &rewriter) const override { // NOLINT
    // The input of this add must be produced by a conv2d op.
paddle::dialect::Conv2dOp conv2d_op =
ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
if (!conv2d_op) return false;
ir::OpResult conv2d_out = conv2d_op.out();
if (!conv2d_out.HasOneUse()) return false;
ir::Value conv2d_filter = conv2d_op.filter();
ir::OpResult conv2d_filter_result = conv2d_filter.dyn_cast<ir::OpResult>();
IR_ENFORCE(conv2d_filter_result);
ir::Value add_input = op.x();
IR_ENFORCE(add_input == conv2d_out);
ir::Value y = op.y();
ir::OpResult bias = y.dyn_cast<ir::OpResult>();
auto conv2d_attributes = conv2d_op.attributes();
std::vector<std::string> conv2d_fusion_attrStr = {"strides",
"paddings_t",
"padding_algorithm",
"dilations_t",
"groups",
"data_format",
"activation",
"exhaustive_search",
"channels",
"user_workspace_size"};
    std::vector<ir::Attribute> conv2d_fusing_attr = {
conv2d_attributes.at("strides"),
conv2d_attributes.at("paddings"),
conv2d_attributes.at("padding_algorithm"),
conv2d_attributes.at("dilations"),
conv2d_attributes.at("groups"),
conv2d_attributes.at("data_format"),
ir::StrAttribute::get(ir::IrContext::Instance(), "identity"),
ir::BoolAttribute::get(ir::IrContext::Instance(), true),
ir::ArrayAttribute::get(ir::IrContext::Instance(),
std::vector<ir::Attribute>()),
ir::Int32Attribute::get(ir::IrContext::Instance(), int32_t(0)),
};
ir::AttributeMap conv2d_fusion_attributes;
for (size_t i = 0; i < conv2d_fusion_attrStr.size(); ++i) {
      conv2d_fusion_attributes[conv2d_fusion_attrStr[i]] = conv2d_fusing_attr[i];
}
ir::OpResult tmpResidual;
auto conv2d_fuse_op = rewriter.Build<paddle::dialect::Conv2dFusionOpTest>(
ir::GetDefiningOpForInput<0>(conv2d_op)->result(0),
conv2d_filter_result,
bias,
tmpResidual,
conv2d_fusion_attributes);
rewriter.ReplaceOp(op, std::vector<ir::Value>{conv2d_fuse_op.output()});
return true;
}
};
class TestPass : public ir::Pass {
 public:
  TestPass() : ir::Pass("TestPass", 1) {}
...
@@ -336,6 +960,7 @@ class TestPass : public ir::Pass {
    ir::RewritePatternSet ps(op->ir_context());
    ps.Add<RedundantTransposeFusePattern>(op->ir_context());
    ps.Add<Conv2dBnFusePattern>(op->ir_context());
ps.Add<Conv2dAddFusePattern>(op->ir_context());
    ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
    ir::GreedyRewriteConfig cfg;
...
@@ -409,6 +1034,8 @@ void BuildProgram(ir::Builder &builder) {  // NOLINT
// TODO(wilber): Add a normal test.
TEST(pattern_rewrite, Patterns) {
  ir::IrContext *ctx = ir::IrContext::Instance();
auto *test_dialect = ctx->GetOrRegisterDialect<Conv2dFusionTestDialect>();
test_dialect->RegisterOp<paddle::dialect::Conv2dFusionOpTest>();
  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
  ir::Program program(ctx);
  ir::Builder builder = ir::Builder(ctx, program.block());
...