diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index dc50cff7df52ed8425a5efb60164e1abc56f412b..0392a8cfba0bed13c54827fba708ba3a0be4a27e 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -347,7 +347,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { return; } - // conv_weight fp32 --> fp16 + // conv_weight fp16 --> fp32 auto* conv_weight_tensor = scope->FindVar(conv_weight->Name())->GetMutable(); auto tensor_type = conv_weight_tensor->dtype(); diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index cb00d4429ab14358f6cfaebb524a6d3ac99c357d..aed9150b547ab9414a264213d70f224cf446e72a 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/platform/profiler/event_tracing.h" -#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h" +#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" diff --git a/paddle/fluid/ir/CMakeLists.txt b/paddle/fluid/ir/CMakeLists.txt index f4e88d5dc1e22c3f7c5701d39cf1e6dd86b8f4c9..d778d43352881776b5375e8a5c9cc8603d589b81 100644 --- a/paddle/fluid/ir/CMakeLists.txt +++ b/paddle/fluid/ir/CMakeLists.txt @@ -1,4 +1,4 @@ add_subdirectory(interface) add_subdirectory(dialect) -add_subdirectory(pass) +add_subdirectory(transforms) add_subdirectory(phi_kernel_adaptor) diff --git a/paddle/fluid/ir/pass/CMakeLists.txt b/paddle/fluid/ir/pass/CMakeLists.txt deleted file mode 100644 index f67add0298669f006fb4dcb6ebc51ff411107128..0000000000000000000000000000000000000000 --- a/paddle/fluid/ir/pass/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# All source files of pd_dialect, except 
for the source file of op, which is generated in the compilation directory. -file(GLOB PD_PASS_SRCS "*.cc") - -cc_library( - pd_op_to_kernel_pass - SRCS ${PD_PASS_SRCS} - DEPS ir phi_utils pd_interface) diff --git a/paddle/fluid/ir/transforms/CMakeLists.txt b/paddle/fluid/ir/transforms/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d131b458dddf1cd04346c0f4d8c7f7afc00b9734 --- /dev/null +++ b/paddle/fluid/ir/transforms/CMakeLists.txt @@ -0,0 +1,9 @@ +cc_library( + transform_general_functions + SRCS transform_general_functions.cc + DEPS ir phi pd_dialect) + +cc_library( + pd_op_to_kernel_pass + SRCS pd_op_to_kernel_pass.cc + DEPS ir phi_utils pd_interface) diff --git a/paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc similarity index 99% rename from paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc rename to paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc index 740c0afb4423100e3ca4be811c523d4a43d820e6..ca43cef77e4b7def85da40a24838055239cb1fb4 100644 --- a/paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc @@ -14,7 +14,7 @@ #include -#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h" +#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir/dialect/kernel_attribute.h" #include "paddle/fluid/ir/dialect/kernel_dialect.h" diff --git a/paddle/fluid/ir/pass/pd_op_to_kernel_pass.h b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h similarity index 100% rename from paddle/fluid/ir/pass/pd_op_to_kernel_pass.h rename to paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h diff --git a/paddle/fluid/ir/transforms/transform_general_functions.cc b/paddle/fluid/ir/transforms/transform_general_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca2a9ccc6d521c3a9d9ba6e75ad30670e5ac4ea9 --- /dev/null +++ b/paddle/fluid/ir/transforms/transform_general_functions.cc @@ -0,0 +1,47 @@ +// 
Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/ir/transforms/transform_general_functions.h" + +#include "paddle/fluid/ir/dialect/pd_dialect.h" +#include "paddle/fluid/ir/dialect/pd_type.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/program.h" + +namespace ir { + +ir::Parameter* GetParameterFromValue(ir::Value value) { + ir::GetParameterOp op = value.GetDefiningOp()->dyn_cast(); + PADDLE_ENFORCE_NOT_NULL( + op, + phi::errors::InvalidArgument( + "Value must be a weight from a GetParameter op.")); + ir::Program* program = op->GetParentProgram(); + std::string name = op->attributes() + .at(op.attributes_name[0]) + .dyn_cast() + .data(); + return program->GetParameter(name); +} + +const phi::DDim& GetShapeFromValue(ir::Value value) { + // TODO(dev): Support other types like DenseTensor. + PADDLE_ENFORCE_EQ( + value.type().isa(), + true, + phi::errors::InvalidArgument("Value's type must be a DenseTensorType.")); + return value.type().dyn_cast().dims(); +} + +} // namespace ir diff --git a/paddle/fluid/ir/transforms/transform_general_functions.h b/paddle/fluid/ir/transforms/transform_general_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..69dafbe1517a11774bfa94dcdf13c714df368f88 --- /dev/null +++ b/paddle/fluid/ir/transforms/transform_general_functions.h @@ -0,0 +1,79 @@ +// Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/parameter.h" +#include "paddle/ir/core/value.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +namespace ir { + +/** + * @brief Get the parameter from a value. + * + * @note The value must be a output of a GetParameterOp. + * + * @param ir::Value + * + * @return ir::Parameter* + */ +ir::Parameter* GetParameterFromValue(ir::Value value); + +/** + * @brief Get tensor's shape from a value. + * + * @param ir::Value + * + * @return const phi::DDim& + */ +const phi::DDim& GetShapeFromValue(ir::Value value); + +/** + * @brief Get an operation that defines the specific input of the operation. + * + * @param Operation* + * + * @return Operation* + */ +template +Operation* GetDefiningOpForInput(Operation* op) { + PADDLE_ENFORCE_EQ( + Index < op->num_operands(), + true, + phi::errors::InvalidArgument("Input operand's index must be valid.")); + return op->operand(Index).GetDefiningOp(); +} + +/** + * @brief Get an operation that is the first to use the specific output of the + * operation. 
+ * + * @param Operation* + * + * @return Operation* + */ +template +Operation* GetFirstUseOperationForOutput(Operation* op) { + PADDLE_ENFORCE_EQ( + Index < op->num_results(), + true, + phi::errors::InvalidArgument("Output op result's index must be valid.")); + return op->result(Index).first_use().owner(); +} + +} // namespace ir diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 0cdfe349d56508b3052502ebdaf620e2232ba891..5348ef81ef962d87385eca5c2aaea3085aaef898 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -107,6 +107,7 @@ Operation *Operation::Create(const std::vector &inputs, // Call destructors for Region , OpResults, Operation, and OpOperands in // sequence, and finally free memory. void Operation::Destroy() { + VLOG(6) << "Destroy Operation [" << name() << "] ..."; // 1. Deconstruct Regions. if (num_regions_ > 0) { for (size_t idx = 0; idx < num_regions_; idx++) { @@ -117,7 +118,8 @@ void Operation::Destroy() { // 2. Deconstruct Result. 
for (size_t idx = 0; idx < num_results_; ++idx) { detail::OpResultImpl *impl = result(idx).impl(); - IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses."); + IR_ENFORCE(impl->use_empty(), + name() + " operation destroyed but still has uses."); if (detail::OpOutlineResultImpl::classof(*impl)) { static_cast(impl)->~OpOutlineResultImpl(); } else { @@ -143,8 +145,8 @@ void Operation::Destroy() { : sizeof(detail::OpInlineResultImpl) * num_results_; void *aligned_ptr = reinterpret_cast(this) - result_mem_size; - VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr - << ", size = " << result_mem_size << "}"; + VLOG(6) << "Destroy Operation [" << name() << "]: {ptr = " << aligned_ptr + << ", size = " << result_mem_size << "} done."; aligned_free(aligned_ptr); } diff --git a/paddle/ir/core/program.cc b/paddle/ir/core/program.cc index 1e1404fa3ca4e7257382653f75a3ed545af3f6ce..baf6a3cbdd57cea0539d3d78841ea79feb9657dd 100644 --- a/paddle/ir/core/program.cc +++ b/paddle/ir/core/program.cc @@ -27,14 +27,14 @@ Program::~Program() { } } -Parameter* Program::GetParameter(std::string name) const { +Parameter* Program::GetParameter(const std::string& name) const { if (parameters_.count(name) != 0) { return parameters_.at(name).get(); } return nullptr; } -void Program::SetParameter(std::string name, +void Program::SetParameter(const std::string& name, std::unique_ptr&& parameter) { parameters_[name].reset(parameter.release()); } diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index 16ff15635f7df82b7edb6b2352263d2ac7ab27c9..a65142b2531e09ec0255218ef80ed2b8f8120608 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -54,8 +54,9 @@ class IR_API Program { Block* block() { return module_.block(); } - Parameter* GetParameter(std::string name) const; - void SetParameter(std::string name, std::unique_ptr&& parameter); + Parameter* GetParameter(const std::string& name) const; + void SetParameter(const std::string& name, + 
std::unique_ptr&& parameter); ParameterMap& parameters() { return parameters_; } void set_parameters(ParameterMap&& parameters) { diff --git a/paddle/ir/pass/pass.cc b/paddle/ir/pass/pass.cc index 0186ea892f0d6a57fe315af1a4ffd886f2a8d5d6..46fb273249c101f5048368d26371ae25834c6016 100644 --- a/paddle/ir/pass/pass.cc +++ b/paddle/ir/pass/pass.cc @@ -18,6 +18,7 @@ #include "paddle/ir/core/operation.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/region.h" +#include "paddle/ir/core/verify.h" #include "paddle/ir/pass/pass_adaptor.h" #include "paddle/ir/pass/pass_instrumentation.h" #include "paddle/ir/pass/pass_manager.h" @@ -109,10 +110,9 @@ bool detail::PassAdaptor::RunPass(Pass* pass, bool pass_failed = pass->pass_state().pass_failed; - // TODO(liuyuanle): Support verification of operation if (!pass_failed && verify) { - // bool verify_recursively = !dynamic_cast(pass); - // pass_failed = ir::Verify(op, verify_recursively); + bool verify_recursively = !dynamic_cast(pass); + ir::Verify(op, verify_recursively); } return !pass_failed; diff --git a/paddle/ir/pattern_rewrite/pattern_match.h b/paddle/ir/pattern_rewrite/pattern_match.h index 6c90f366564fc9c28813b80a0368971375c8a749..26a6a1842b91c2ef9c438254877e9686fd52bdad 100644 --- a/paddle/ir/pattern_rewrite/pattern_match.h +++ b/paddle/ir/pattern_rewrite/pattern_match.h @@ -274,7 +274,7 @@ class RewriterBase : public Builder { virtual void EraseOp(Operation* op); - void ReplaceAllUsesWith(Value from, Value to); + IR_API void ReplaceAllUsesWith(Value from, Value to); void ReplaceUseIf(Value from, Value to, diff --git a/test/cpp/ir/core/ir_exe_test.cc b/test/cpp/ir/core/ir_exe_test.cc index 0b3f956cd47c9f386f0a888b00e80d36bc182e82..ee10a75e1184f789c2fa983597975f92aa360992 100644 --- a/test/cpp/ir/core/ir_exe_test.cc +++ b/test/cpp/ir/core/ir_exe_test.cc @@ -42,8 +42,8 @@ #include "paddle/fluid/ir/dialect/pd_attribute.h" -#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h" #include 
"paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h" +#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/ir/core/attribute.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc b/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc index 5deb7ae9b8ef85eb9ea1ec5261e4f3d20234d711..3727ce4bd47edceec6818397784376d7e536e928 100644 --- a/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc +++ b/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc @@ -26,8 +26,8 @@ #include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" -#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h" +#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/platform/init.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_dialect.h" diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc index b77df8a092097d6b8afb86afe0b56c92cff5a77a..501249dd6f517e71286be100a6691ee6853af5cf 100644 --- a/test/cpp/ir/pass/pass_manager_test.cc +++ b/test/cpp/ir/pass/pass_manager_test.cc @@ -94,7 +94,8 @@ IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) struct CountOpAnalysis { explicit CountOpAnalysis(ir::Operation *container_op) { - IR_ENFORCE(container_op->num_regions() > 0, true); + IR_ENFORCE(container_op->num_regions() > 0, + "op must be a container with one or more regions."); LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n"; for (size_t i = 0; i < container_op->num_regions(); ++i) { diff --git a/test/cpp/ir/pattern_rewrite/CMakeLists.txt b/test/cpp/ir/pattern_rewrite/CMakeLists.txt index 62dfa3b8dece57bd6cb6a7447777593facfa94a4..f3e2cbbee22f410937230c4bf04ff4e784099f66 100644 --- a/test/cpp/ir/pattern_rewrite/CMakeLists.txt +++ 
b/test/cpp/ir/pattern_rewrite/CMakeLists.txt @@ -5,4 +5,5 @@ cc_test_old( DEPS ir pd_dialect + transform_general_functions gtest) diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index 068a78be5e510c683f298f4095a501c40894ec14..8a8a73b09345965d651db2a43952607111a3f5ce 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -13,12 +13,14 @@ // limitations under the License. #include +#include #include #include #include #include #include "paddle/fluid/ir/dialect/pd_attribute.h" +#include "paddle/fluid/ir/transforms/transform_general_functions.h" #include "paddle/ir/core/builder.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_dialect.h" @@ -28,7 +30,9 @@ #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_info.h" +#include "paddle/ir/core/parameter.h" #include "paddle/ir/core/program.h" +#include "paddle/ir/core/value.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" @@ -39,9 +43,11 @@ // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/ir/dialect/CMakeLists.txt. -#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_op.h" + +#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_type.h" +#include "paddle/phi/core/ddim.h" // Define op1. 
class Operation1 : public ir::Op { @@ -53,6 +59,7 @@ class Operation1 : public ir::Op { void Verify(); static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; + void Operation1::Verify() { auto &attributes = this->attributes(); if (attributes.count("op2_attr1") == 0 || @@ -183,7 +190,7 @@ class TransposePatternRewrite bool MatchAndRewrite(paddle::dialect::TransposeOp op, ir::PatternRewriter &rewriter) const override { - auto prev_op = op->operand(0).GetDefiningOp(); + auto prev_op = ir::GetDefiningOpForInput<0>(op); std::vector axis_last = GetAxis(op); auto prev_trans_op = prev_op->dyn_cast(); if (prev_trans_op) { @@ -192,9 +199,9 @@ class TransposePatternRewrite "tranpose op's perm rank should be same."); auto new_perm = GetPerm(axis_first, axis_last); rewriter.SetInsertionPoint(op); - auto new_op = rewriter.Build( - prev_op->operand(0).GetDefiningOp()->result(0), new_perm); - rewriter.ReplaceOp(op, {new_op.out()}); + auto new_transpose_op = rewriter.Build( + ir::GetDefiningOpForInput<0>(prev_trans_op)->result(0), new_perm); + rewriter.ReplaceOp(op, {new_transpose_op.out()}); return true; } @@ -203,9 +210,7 @@ class TransposePatternRewrite private: std::vector GetAxis(paddle::dialect::TransposeOp op) const { - auto attr_map = op->attributes(); - ir::ArrayAttribute array_attr = - attr_map.at("perm").dyn_cast(); + auto array_attr = op.attribute("perm").data(); std::vector axis(array_attr.size()); for (size_t i = 0; i < array_attr.size(); ++i) { axis[i] = array_attr[i].dyn_cast().data(); @@ -228,12 +233,121 @@ class TransposePatternRewrite } }; +class Conv2dBnFusePattern + : public ir::OpRewritePattern { + public: + using ir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite( + paddle::dialect::BatchNormOp op, + ir::PatternRewriter &rewriter) const override { // NOLINT + // The next op should be batch_norm. 
+ paddle::dialect::Conv2dOp conv2d_op = + ir::GetDefiningOpForInput(op)->dyn_cast(); + if (!conv2d_op) return false; + + ir::OpResult conv2d_out = conv2d_op.out(); + if (!conv2d_out.HasOneUse()) return false; + + ir::Value conv2d_filter = conv2d_op.filter(); + + // ir::GetParameterOp filter_parameter_op = + // conv2d_filter.GetDefiningOp()->dyn_cast(); + // if (!filter_parameter_op) return false; + + ir::OpResult conv2d_filter_result = conv2d_filter.dyn_cast(); + IR_ENFORCE(conv2d_filter_result); + + ir::Value bn_input = op.x(); + IR_ENFORCE(bn_input == conv2d_out); + + ir::Value bn_mean = op.mean(); + ir::Value bn_variance = op.variance(); + ir::Value bn_scale = op.scale(); + ir::Value bn_bias = op.bias(); + + ir::OpResult bn_mean_result = bn_mean.dyn_cast(); + IR_ENFORCE(bn_mean_result); + ir::OpResult bn_variance_result = bn_variance.dyn_cast(); + IR_ENFORCE(bn_variance_result); + ir::OpResult bn_scale_result = bn_scale.dyn_cast(); + IR_ENFORCE(bn_scale_result); + ir::OpResult bn_bias_result = bn_bias.dyn_cast(); + IR_ENFORCE(bn_bias_result); + + // --- deal with filter --- + rewriter.SetInsertionPoint(conv2d_op); + phi::DDim bn_variance_shape = + bn_variance.type().dyn_cast().dims(); + float epsilon = op.attribute("epsilon").data(); + paddle::dialect::FullOp full_op = rewriter.Build( + phi::vectorize(bn_variance_shape), epsilon); + paddle::dialect::AddOp add_op = rewriter.Build( + bn_variance_result, full_op.out()); + paddle::dialect::SqrtOp sqrt_op = + rewriter.Build(add_op.out()); + paddle::dialect::DivideOp div_op = + rewriter.Build(bn_scale_result, + sqrt_op.out()); + + // reshape scale + phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter); + phi::DDim bn_scale_shape = + bn_scale.type().dyn_cast().dims(); + std::vector bn_scale_new_shape(conv2d_filter_shape.size(), 1); + bn_scale_new_shape[0] = bn_scale_shape[0]; + + paddle::dialect::ReshapeOp reshape_scale_op = + rewriter.Build(div_op.out(), + bn_scale_new_shape); + // new filter --> 
mul_op.out() + paddle::dialect::MultiplyOp mul_op = + rewriter.Build(conv2d_filter_result, + reshape_scale_op.out()); + // TODO(liuyuanle): Use rewriter. + conv2d_op->op_operand(1).set_source(mul_op.out()); + + // --- deal with bias --- + rewriter.SetInsertionPoint(op); + paddle::dialect::MultiplyOp mul_bias_op = + rewriter.Build(bn_mean_result, + div_op.out()); + // new bias --> sub_op.out() + paddle::dialect::SubtractOp sub_op = + rewriter.Build(bn_bias_result, + mul_bias_op.out()); + + // reshape new bias + phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out); + std::vector new_bias_new_shape(conv2d_out_shape.size(), 1); + std::string data_format = + conv2d_op.attribute("data_format").data(); + + IR_ENFORCE(data_format == "NCHW", "Only support NCHW now."); + new_bias_new_shape[0] = conv2d_out_shape[0]; + new_bias_new_shape[1] = conv2d_out_shape[1]; + + paddle::dialect::ReshapeOp reshape_bias_op = + rewriter.Build(sub_op.out(), + new_bias_new_shape); + + paddle::dialect::AddOp add_bias_op = rewriter.Build( + conv2d_out, reshape_bias_op.out()); + auto next_op = ir::GetFirstUseOperationForOutput<0>(op); + rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out()); + + rewriter.EraseOp(op); + return true; + } +}; + class TestPass : public ir::Pass { public: TestPass() : ir::Pass("TestPass", 1) {} void Run(ir::Operation *op) override { ir::RewritePatternSet ps(op->ir_context()); ps.Add(op->ir_context()); + ps.Add(op->ir_context()); + ir::FrozenRewritePatternSet frozen_ps(std::move(ps)); ir::GreedyRewriteConfig cfg; cfg.use_top_down_traversal = true; @@ -247,15 +361,55 @@ class TestPass : public ir::Pass { }; void BuildProgram(ir::Builder &builder) { // NOLINT - paddle::dialect::FullOp full_op = + paddle::dialect::FullOp full_input_op = builder.Build(std::vector{1, 3, 16, 16}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); - ir::OpResult full_op_output = full_op->result(0); + + paddle::dialect::FullOp full_filter_op = + builder.Build(std::vector{64, 
3, 3, 3}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullOp full_mean_op = builder.Build( + std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::FullOp full_variance_op = + builder.Build(std::vector{64}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullOp full_scale_op = + builder.Build(std::vector{64}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullOp full_bias_op = builder.Build( + std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::Conv2dOp conv2d_op = + builder.Build(full_input_op.out(), + full_filter_op.out()); + + paddle::dialect::BatchNormOp batch_norm_op = + builder.Build(conv2d_op.out(), + full_mean_op.out(), + full_variance_op.out(), + full_scale_op.out(), + full_bias_op.out(), + true, + 0.9, + 1e-6, + "NCHW", + false, + false); auto transpose1_op = builder.Build( - full_op_output, std::vector{0, 2, 3, 1}); + batch_norm_op.out(), std::vector{0, 2, 3, 1}); auto transpose2_op = builder.Build( transpose1_op.out(), std::vector{0, 3, 1, 2}); @@ -264,22 +418,22 @@ void BuildProgram(ir::Builder &builder) { // NOLINT } // TODO(wilber): Add a normal test. 
-TEST(PatternRewrite, GreedyPatternRewriteDriver) { +TEST(pattern_rewrite, Patterns) { ir::IrContext *ctx = ir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ir::Program program(ctx); ir::Builder builder = ir::Builder(ctx, program.block()); BuildProgram(builder); - EXPECT_EQ(program.block()->size(), 4u); + + EXPECT_EQ(program.block()->size(), 11u); ir::PassManager pm(ctx); pm.AddPass(std::make_unique()); pm.AddPass(ir::CreateDCEPass()); - std::stringstream o1, o2; - program.Print(o1); - LOG(INFO) << o1.str(); + program.Print(std::cout); + std::cout << std::endl; pm.Run(&program); LOG(INFO) << "After Pass."; - program.Print(o2); - LOG(INFO) << o2.str(); + program.Print(std::cout); + std::cout << std::endl; } diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc index 0a4a5ab0428da721f18168671b32b89e0b3a6f84..c08c590b7735c25cef13e1bcdf65ca5ba60675ae 100644 --- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc +++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc @@ -24,7 +24,7 @@ #include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_op.h" -#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h" +#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/ir/core/builder.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/program.h"