Unverified commit 19345fa7 authored by Yuanle Liu, committed by GitHub

[IR&PASS] add conv + bn fuse pattern, and other works (#54933)

Parent 0f69d932
......@@ -347,7 +347,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
return;
}
// conv_weight fp32 --> fp16
// conv_weight fp16 --> fp32
auto* conv_weight_tensor =
scope->FindVar(conv_weight->Name())->GetMutable<phi::DenseTensor>();
auto tensor_type = conv_weight_tensor->dtype();
......
......@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
......
add_subdirectory(interface)
add_subdirectory(dialect)
add_subdirectory(pass)
add_subdirectory(transforms)
add_subdirectory(phi_kernel_adaptor)
# All source files of pd_dialect, except for the source file of op, which is generated in the compilation directory.
file(GLOB PD_PASS_SRCS "*.cc")
cc_library(
pd_op_to_kernel_pass
SRCS ${PD_PASS_SRCS}
DEPS ir phi_utils pd_interface)
cc_library(
transform_general_functions
SRCS transform_general_functions.cc
DEPS ir phi pd_dialect)
cc_library(
pd_op_to_kernel_pass
SRCS pd_op_to_kernel_pass.cc
DEPS ir phi_utils pd_interface)
......@@ -14,7 +14,7 @@
#include <iostream>
#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/dialect/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/kernel_dialect.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
namespace ir {
ir::Parameter* GetParameterFromValue(ir::Value value) {
ir::GetParameterOp op = value.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
op,
phi::errors::InvalidArgument(
"Value must be a weight from a GetParameter op."));
ir::Program* program = op->GetParentProgram();
std::string name = op->attributes()
.at(op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data();
return program->GetParameter(name);
}
const phi::DDim& GetShapeFromValue(ir::Value value) {
// TODO(dev): Support other types like DenseTensor.
PADDLE_ENFORCE_EQ(
value.type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument("Value's type must be a DenseTensorType."));
return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace ir {
/**
* @brief Get the parameter from a value.
*
* @note The value must be an output of a GetParameterOp.
*
* @param ir::Value
*
* @return ir::Parameter*
*/
ir::Parameter* GetParameterFromValue(ir::Value value);
/**
* @brief Get tensor's shape from a value.
*
* @param ir::Value
*
* @return const phi::DDim&
*/
const phi::DDim& GetShapeFromValue(ir::Value value);
/**
* @brief Get an operation that defines the specific input of the operation.
*
* @param Operation*
*
* @return Operation*
*/
template <uint32_t Index = 0>
Operation* GetDefiningOpForInput(Operation* op) {
PADDLE_ENFORCE_EQ(
Index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand(Index).GetDefiningOp();
}
/**
* @brief Get an operation that is the first to use the specific output of the
* operation.
*
* @param Operation*
*
* @return Operation*
*/
template <uint32_t Index = 0>
Operation* GetFirstUseOperationForOutput(Operation* op) {
PADDLE_ENFORCE_EQ(
Index < op->num_results(),
true,
phi::errors::InvalidArgument("Output op result's index must be valid."));
return op->result(Index).first_use().owner();
}
} // namespace ir
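A minimal usage sketch for the helpers declared above (not part of the patch; `op` is a hypothetical operation whose 0-th input is produced by an ir::GetParameterOp):
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
void InspectOp(ir::Operation* op) {
  // Producer of op's 0-th input; Index is checked against num_operands().
  ir::Operation* def_op = ir::GetDefiningOpForInput<0>(op);
  // Look up the weight loaded by the producer; GetParameterFromValue
  // enforces that the value really comes from a GetParameterOp.
  ir::Parameter* param = ir::GetParameterFromValue(def_op->result(0));
  // Shape of any DenseTensor-typed value.
  const phi::DDim& dims = ir::GetShapeFromValue(def_op->result(0));
  // First consumer of op's 0-th result.
  ir::Operation* first_user = ir::GetFirstUseOperationForOutput<0>(op);
  (void)param; (void)dims; (void)first_user;
}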
......@@ -107,6 +107,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
// Call destructors for Region, OpResults, Operation, and OpOperands in
// sequence, and finally free memory.
void Operation::Destroy() {
VLOG(6) << "Destroy Operation [" << name() << "] ...";
// 1. Deconstruct Regions.
if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) {
......@@ -117,7 +118,8 @@ void Operation::Destroy() {
// 2. Deconstruct Result.
for (size_t idx = 0; idx < num_results_; ++idx) {
detail::OpResultImpl *impl = result(idx).impl();
IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses.");
IR_ENFORCE(impl->use_empty(),
name() + " operation destroyed but still has uses.");
if (detail::OpOutlineResultImpl::classof(*impl)) {
static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
} else {
......@@ -143,8 +145,8 @@ void Operation::Destroy() {
: sizeof(detail::OpInlineResultImpl) * num_results_;
void *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr
<< ", size = " << result_mem_size << "}";
VLOG(6) << "Destroy Operation [" << name() << "]: {ptr = " << aligned_ptr
<< ", size = " << result_mem_size << "} done.";
aligned_free(aligned_ptr);
}
......
......@@ -27,14 +27,14 @@ Program::~Program() {
}
}
Parameter* Program::GetParameter(std::string name) const {
Parameter* Program::GetParameter(const std::string& name) const {
if (parameters_.count(name) != 0) {
return parameters_.at(name).get();
}
return nullptr;
}
void Program::SetParameter(std::string name,
void Program::SetParameter(const std::string& name,
std::unique_ptr<Parameter>&& parameter) {
parameters_[name].reset(parameter.release());
}
......
......@@ -54,8 +54,9 @@ class IR_API Program {
Block* block() { return module_.block(); }
Parameter* GetParameter(std::string name) const;
void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter);
Parameter* GetParameter(const std::string& name) const;
void SetParameter(const std::string& name,
std::unique_ptr<Parameter>&& parameter);
ParameterMap& parameters() { return parameters_; }
void set_parameters(ParameterMap&& parameters) {
......
......@@ -18,6 +18,7 @@
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/verify.h"
#include "paddle/ir/pass/pass_adaptor.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
......@@ -109,10 +110,9 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
bool pass_failed = pass->pass_state().pass_failed;
// TODO(liuyuanle): Support verification of operation
if (!pass_failed && verify) {
// bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
// pass_failed = ir::Verify(op, verify_recursively);
bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
ir::Verify(op, verify_recursively);
}
return !pass_failed;
......
......@@ -274,7 +274,7 @@ class RewriterBase : public Builder {
virtual void EraseOp(Operation* op);
void ReplaceAllUsesWith(Value from, Value to);
IR_API void ReplaceAllUsesWith(Value from, Value to);
void ReplaceUseIf(Value from,
Value to,
......
......@@ -42,8 +42,8 @@
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/phi/core/kernel_registry.h"
......
......@@ -26,8 +26,8 @@
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
......
......@@ -94,7 +94,8 @@ IR_DEFINE_EXPLICIT_TYPE_ID(AddOp)
struct CountOpAnalysis {
explicit CountOpAnalysis(ir::Operation *container_op) {
IR_ENFORCE(container_op->num_regions() > 0, true);
IR_ENFORCE(container_op->num_regions() > 0,
"op must be a container with zero or multiple regions.");
LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n";
for (size_t i = 0; i < container_op->num_regions(); ++i) {
......
......@@ -5,4 +5,5 @@ cc_test_old(
DEPS
ir
pd_dialect
transform_general_functions
gtest)
......@@ -13,12 +13,14 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <sstream>
#include <vector>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
......@@ -28,7 +30,9 @@
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/value.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
......@@ -39,9 +43,11 @@
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/phi/core/ddim.h"
// Define op1.
class Operation1 : public ir::Op<Operation1> {
......@@ -53,6 +59,7 @@ class Operation1 : public ir::Op<Operation1> {
void Verify();
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
};
void Operation1::Verify() {
auto &attributes = this->attributes();
if (attributes.count("op2_attr1") == 0 ||
......@@ -183,7 +190,7 @@ class TransposePatternRewrite
bool MatchAndRewrite(paddle::dialect::TransposeOp op,
ir::PatternRewriter &rewriter) const override {
auto prev_op = op->operand(0).GetDefiningOp();
auto prev_op = ir::GetDefiningOpForInput<0>(op);
std::vector<int> axis_last = GetAxis(op);
auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
if (prev_trans_op) {
......@@ -192,9 +199,9 @@ class TransposePatternRewrite
"tranpose op's perm rank should be same.");
auto new_perm = GetPerm(axis_first, axis_last);
rewriter.SetInsertionPoint(op);
auto new_op = rewriter.Build<paddle::dialect::TransposeOp>(
prev_op->operand(0).GetDefiningOp()->result(0), new_perm);
rewriter.ReplaceOp(op, {new_op.out()});
auto new_transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
ir::GetDefiningOpForInput<0>(prev_trans_op)->result(0), new_perm);
rewriter.ReplaceOp(op, {new_transpose_op.out()});
return true;
}
......@@ -203,9 +210,7 @@ class TransposePatternRewrite
private:
std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const {
auto attr_map = op->attributes();
ir::ArrayAttribute array_attr =
attr_map.at("perm").dyn_cast<ir::ArrayAttribute>();
auto array_attr = op.attribute<ir::ArrayAttribute>("perm").data();
std::vector<int> axis(array_attr.size());
for (size_t i = 0; i < array_attr.size(); ++i) {
axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data();
......@@ -228,12 +233,121 @@ class TransposePatternRewrite
}
};
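GetPerm's implementation is collapsed out of this diff, but the rewrite only works if it composes the two perms correctly: applying perm `first` and then `second` is a single transpose whose output axis i reads from axis first[second[i]] of the original tensor. A plausible sketch of that composition (the name ComposePerm is hypothetical, not from the patch):
std::vector<int> ComposePerm(const std::vector<int>& first,
                             const std::vector<int>& second) {
  // Output axis i of the second transpose reads axis second[i] of the
  // first transpose's output, i.e. axis first[second[i]] of the input.
  std::vector<int> combined(second.size());
  for (size_t i = 0; i < second.size(); ++i) {
    combined[i] = first[second[i]];
  }
  return combined;
}
For the perms built in BuildProgram below, {0, 2, 3, 1} followed by {0, 3, 1, 2} composes to the identity {0, 1, 2, 3}, so the pair folds into a no-op transpose and the now-dead originals are cleaned up by the DCE pass added in the test.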
class Conv2dBnFusePattern
: public ir::OpRewritePattern<paddle::dialect::BatchNormOp> {
public:
using ir::OpRewritePattern<paddle::dialect::BatchNormOp>::OpRewritePattern;
bool MatchAndRewrite(
paddle::dialect::BatchNormOp op,
ir::PatternRewriter &rewriter) const override { // NOLINT
// The producer of bn's input should be a conv2d op.
paddle::dialect::Conv2dOp conv2d_op =
ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
if (!conv2d_op) return false;
ir::OpResult conv2d_out = conv2d_op.out();
if (!conv2d_out.HasOneUse()) return false;
ir::Value conv2d_filter = conv2d_op.filter();
// ir::GetParameterOp filter_parameter_op =
// conv2d_filter.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
// if (!filter_parameter_op) return false;
ir::OpResult conv2d_filter_result = conv2d_filter.dyn_cast<ir::OpResult>();
IR_ENFORCE(conv2d_filter_result);
ir::Value bn_input = op.x();
IR_ENFORCE(bn_input == conv2d_out);
ir::Value bn_mean = op.mean();
ir::Value bn_variance = op.variance();
ir::Value bn_scale = op.scale();
ir::Value bn_bias = op.bias();
ir::OpResult bn_mean_result = bn_mean.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_mean_result);
ir::OpResult bn_variance_result = bn_variance.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_variance_result);
ir::OpResult bn_scale_result = bn_scale.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_scale_result);
ir::OpResult bn_bias_result = bn_bias.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_bias_result);
// --- deal with filter ---
rewriter.SetInsertionPoint(conv2d_op);
phi::DDim bn_variance_shape =
bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
float epsilon = op.attribute<ir::FloatAttribute>("epsilon").data();
paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
phi::vectorize(bn_variance_shape), epsilon);
paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
bn_variance_result, full_op.out());
paddle::dialect::SqrtOp sqrt_op =
rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
paddle::dialect::DivideOp div_op =
rewriter.Build<paddle::dialect::DivideOp>(bn_scale_result,
sqrt_op.out());
// reshape scale
phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter);
phi::DDim bn_scale_shape =
bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
bn_scale_new_shape[0] = bn_scale_shape[0];
paddle::dialect::ReshapeOp reshape_scale_op =
rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
bn_scale_new_shape);
// new filter --> mul_op.out()
paddle::dialect::MultiplyOp mul_op =
rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
reshape_scale_op.out());
// TODO(liuyuanle): Use rewriter.
conv2d_op->op_operand(1).set_source(mul_op.out());
// --- deal with bias ---
rewriter.SetInsertionPoint(op);
paddle::dialect::MultiplyOp mul_bias_op =
rewriter.Build<paddle::dialect::MultiplyOp>(bn_mean_result,
div_op.out());
// new bias --> sub_op.out()
paddle::dialect::SubtractOp sub_op =
rewriter.Build<paddle::dialect::SubtractOp>(bn_bias_result,
mul_bias_op.out());
// reshape new bias
phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out);
std::vector<int64_t> new_bias_new_shape(conv2d_out_shape.size(), 1);
std::string data_format =
conv2d_op.attribute<ir::StrAttribute>("data_format").data();
IR_ENFORCE(data_format == "NCHW", "Only NCHW is supported for now.");
new_bias_new_shape[0] = conv2d_out_shape[0];
new_bias_new_shape[1] = conv2d_out_shape[1];
paddle::dialect::ReshapeOp reshape_bias_op =
rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
new_bias_new_shape);
paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
conv2d_out, reshape_bias_op.out());
auto next_op = ir::GetFirstUseOperationForOutput<0>(op);
rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out());
rewriter.EraseOp(op);
return true;
}
};
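The pattern above implements the standard conv + batch_norm folding identity. With the conv output x = W * in (no conv bias in this test), BN statistics mu and sigma^2, scale gamma, and bias beta, the BN transform folds into the filter plus one bias add:
\[
\mathrm{bn}(x) = \gamma\,\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta
               = (\alpha \odot W) * \mathrm{in} + \bigl(\beta - \alpha \odot \mu\bigr),
\qquad \alpha = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}
\]
In the code, div_op produces alpha, mul_op produces the folded filter (alpha reshaped so it broadcasts over the filter's output-channel axis), and sub_op produces the new bias beta - alpha * mu, which reshape_bias_op and add_bias_op apply to the conv output.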
class TestPass : public ir::Pass {
public:
TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
ir::RewritePatternSet ps(op->ir_context());
ps.Add<TransposePatternRewrite>(op->ir_context());
ps.Add<Conv2dBnFusePattern>(op->ir_context());
ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
......@@ -247,15 +361,55 @@ class TestPass : public ir::Pass {
};
void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_op =
paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 3, 16, 16},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
ir::OpResult full_op_output = full_op->result(0);
paddle::dialect::FullOp full_filter_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 3, 3, 3},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_mean_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp full_variance_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_scale_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_bias_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::Conv2dOp conv2d_op =
builder.Build<paddle::dialect::Conv2dOp>(full_input_op.out(),
full_filter_op.out());
paddle::dialect::BatchNormOp batch_norm_op =
builder.Build<paddle::dialect::BatchNormOp>(conv2d_op.out(),
full_mean_op.out(),
full_variance_op.out(),
full_scale_op.out(),
full_bias_op.out(),
true,
0.9,
1e-6,
"NCHW",
false,
false);
auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
full_op_output, std::vector<int>{0, 2, 3, 1});
batch_norm_op.out(), std::vector<int>{0, 2, 3, 1});
auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
......@@ -264,22 +418,22 @@ void BuildProgram(ir::Builder &builder) { // NOLINT
}
// TODO(wilber): Add a normal test.
TEST(PatternRewrite, GreedyPatternRewriteDriver) {
TEST(pattern_rewrite, Patterns) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder(ctx, program.block());
BuildProgram(builder);
EXPECT_EQ(program.block()->size(), 4u);
EXPECT_EQ(program.block()->size(), 11u);
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDCEPass());
std::stringstream o1, o2;
program.Print(o1);
LOG(INFO) << o1.str();
program.Print(std::cout);
std::cout << std::endl;
pm.Run(&program);
LOG(INFO) << "After Pass.";
program.Print(o2);
LOG(INFO) << o2.str();
program.Print(std::cout);
std::cout << std::endl;
}
......@@ -24,7 +24,7 @@
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
......