Unverified  Commit 4905a247  authored by Yuanle Liu, committed by GitHub

[PASS] add constant folding pass (#55099)

Parent df311526
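For context, a minimal sketch of how the new pass is wired up, mirroring the updated pattern_rewrite_test.cc later in this diff (BuildProgram is the helper added there and stands in for any program whose ops take only constant/parameter inputs):

// Sketch only, based on the test changes in this PR.
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder(ctx, program.block());
BuildProgram(builder);  // helper from pattern_rewrite_test.cc

ir::PassManager pm(ctx);
pm.AddPass(ir::CreateConstantFoldingPass());
pm.AddPass(ir::CreateDeadCodeEliminationPass());
CHECK_EQ(pm.Run(&program), true);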
......@@ -139,7 +139,6 @@ void HandleForSpecialOp(ir::Operation* op,
if (op_name == "pd.fetch") {
// fetch is a very special op, with no output
VLOG(6) << "Handle for pd.fetch:";
for (size_t i = 0; i < input_num; ++i) {
auto var = scope->Var("fetch");
VLOG(6) << "Create var: fetch in scope " << scope;
auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
......@@ -147,7 +146,6 @@ void HandleForSpecialOp(ir::Operation* op,
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
fetch_list->resize(index + 1);
}
}
if (op_name == "pd.feed") {
VLOG(6) << "Handle for pd.feed:";
......
......@@ -7,3 +7,9 @@ cc_library(
pd_op_to_kernel_pass
SRCS pd_op_to_kernel_pass.cc
DEPS phi_utils pd_interface pd_trait ir)
cc_library(
_constant_folding_pass
SRCS constant_folding_pass.cc
DEPS standalone_executor phi pd_op_to_kernel_pass transform_general_functions
ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace {
class ConstantFoldingPattern : public ir::RewritePattern {
public:
ConstantFoldingPattern(ir::IrContext* context,
ir::PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) {
}
bool Match(ir::Operation* op) const override {
// TODO(liuyuanle): Use trait to improve robustness.
if (op->dyn_cast<ir::GetParameterOp>() ||
op->dyn_cast<ir::SetParameterOp>() ||
op->dyn_cast<paddle::dialect::FetchOp>())
return false;
// Inputs must come from get parameter op.
for (uint32_t i = 0; i < op->num_operands(); ++i)
if (ir::GetDefiningOpForInput(op, i)->dyn_cast<ir::GetParameterOp>() ==
nullptr)
return false;
return true;
}
void Rewrite(ir::Operation* op,
ir::PatternRewriter& rewriter) const override { // NOLINT
ir::Program* program = op->GetParentProgram();
auto temp_program = BuildProgramFromOperation(op);
// Execute program
paddle::framework::interpreter::ExecutionConfig exe_config;
exe_config.create_local_scope = false;
paddle::framework::InterpreterCore core(
phi::CPUPlace{},
paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
&scope_,
exe_config);
paddle::framework::FetchList fetch_list = core.Run({});
// TODO(liuyuanle): Support multiple output.
auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
std::unique_ptr<ir::Parameter> parameter = std::make_unique<ir::Parameter>(
reinterpret_cast<void*>(out_tensor.data()),
out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
op->result(0).type());
std::string param_name =
"@constant_folding_pass@_" + std::to_string(suffix_++);
auto* param_var = scope_.Var(param_name);
auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
*param_tensor = out_tensor;
program->SetParameter(param_name, std::move(parameter));
// rewriter.SetInsertionPoint(op);
auto get_parameter_op =
rewriter.Build<ir::GetParameterOp>(param_name, op->result(0).type());
rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op);
}
private:
std::unique_ptr<ir::Program> BuildProgramFromOperation(
ir::Operation* op) const {
auto program = std::make_unique<ir::Program>(ir_context());
ir::Builder builder = ir::Builder(ir_context(), program->block());
// prepare op inputs
std::vector<ir::OpResult> op_inputs;
for (uint32_t i = 0; i < op->num_operands(); i++) {
PADDLE_ENFORCE_EQ(
op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));
auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
program->SetParameter(param_name,
std::make_unique<ir::Parameter>(*param));
auto* param_var = scope_.FindVar(param_name);
PADDLE_ENFORCE_NOT_NULL(
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));
auto get_parameter_op =
builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
op_inputs.push_back(get_parameter_op->result(0));
}
// prepare op outputs
std::vector<ir::Type> output_types;
for (uint32_t i = 0; i < op->num_results(); i++) {
output_types.push_back(op->result(i).type());
}
auto* temp_op =
builder.Build(op_inputs, op->attributes(), output_types, op->info());
// TODO(liuyuanle): Support multiple output.
// for (uint32_t i = 0; i < op->num_results(); i++) {
PADDLE_ENFORCE_EQ(
temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's output must be a dense tensor type."));
builder.Build<paddle::dialect::FetchOp>(
temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
// }
return program;
}
private:
static size_t suffix_;
static paddle::framework::Scope scope_;
};
size_t ConstantFoldingPattern::suffix_ = 0;
paddle::framework::Scope ConstantFoldingPattern::scope_ = {};
class ConstantFoldingPass : public ir::Pass {
public:
// TODO(liuyuanle): Naming convention for pass.
ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {}
bool Initialize(ir::IrContext* context) override {
ir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(context);
patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
}
void Run(ir::Operation* op) override {
ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}
bool CanApplyOn(ir::Operation* op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
private:
ir::FrozenRewritePatternSet patterns_;
};
} // namespace
namespace ir {
std::unique_ptr<Pass> CreateConstantFoldingPass() {
return std::make_unique<ConstantFoldingPass>();
}
} // namespace ir
......@@ -18,8 +18,9 @@
#include "paddle/ir/core/dll_decl.h"
namespace ir {
class Pass;
IR_API std::unique_ptr<Pass> CreateDcePass();
IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
} // namespace ir
......@@ -17,22 +17,28 @@
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
namespace ir {
ir::Parameter* GetParameterFromValue(ir::Value value) {
std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) {
ir::GetParameterOp op = value.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
op,
phi::errors::InvalidArgument(
"Value must be a weight from a GetParameter op."));
ir::Program* program = op->GetParentProgram();
PADDLE_ENFORCE_NOT_NULL(
program, phi::errors::InvalidArgument("Program should not be null."));
std::string name = op->attributes()
.at(op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data();
return program->GetParameter(name);
ir::Parameter* param = program->GetParameter(name);
PADDLE_ENFORCE_NOT_NULL(
param, phi::errors::InvalidArgument("Parameter should not be null."));
return {name, param};
}
const phi::DDim& GetShapeFromValue(ir::Value value) {
......@@ -44,4 +50,29 @@ const phi::DDim& GetShapeFromValue(ir::Value value) {
return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
}
ir::Type GetDataTypeFromValue(ir::Value value) {
// TODO(dev): Support other types like DenseTensor.
PADDLE_ENFORCE_EQ(
value.type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument("Value's type must be a DenseTensorType."));
return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
}
Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
PADDLE_ENFORCE_EQ(
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand(index).GetDefiningOp();
}
Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
PADDLE_ENFORCE_EQ(
index < op->num_results(),
true,
phi::errors::InvalidArgument("Output op result's index must be valid."));
return op->result(index).first_use().owner();
}
} // namespace ir
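A short usage sketch of the updated GetParameterFromValue, which now returns the parameter name together with the parameter; the copy into a second program mirrors constant_folding_pass.cc, and `other_program` is only a placeholder name:

// Assumes `value` is the output of an ir::GetParameterOp, as enforced above.
auto [param_name, param] = ir::GetParameterFromValue(value);
// The returned pointer is owned by the defining op's parent program; copy it
// when handing the parameter to another program (placeholder name).
other_program->SetParameter(param_name, std::make_unique<ir::Parameter>(*param));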
......@@ -16,6 +16,7 @@
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
......@@ -24,15 +25,16 @@
namespace ir {
/**
* @brief Get the parameter from a value.
* @brief Get the [name, parameter] pair of a parameter from a value.
*
* @note The value must be an output of a GetParameterOp.
*
* @param ir::Value
*
* @return ir::Parameter*
* @return std::pair<std::string, ir::Parameter*>
*/
ir::Parameter* GetParameterFromValue(ir::Value value);
std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value);
/**
* @brief Get tensor's shape from a value.
......@@ -43,37 +45,34 @@ ir::Parameter* GetParameterFromValue(ir::Value value);
*/
const phi::DDim& GetShapeFromValue(ir::Value value);
/**
* @brief Get tensor's data type from a value.
*
* @param ir::Value
*
* @return ir::Type
*/
ir::Type GetDataTypeFromValue(ir::Value value);
/**
* @brief Get an operation that defines the specific input of the operation.
*
* @param Operation*
* @param Operation* pointer to an operation
* @param uint32_t index of operand of the operation
*
* @return Operation*
*/
template <uint32_t Index = 0>
Operation* GetDefiningOpForInput(Operation* op) {
PADDLE_ENFORCE_EQ(
Index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand(Index).GetDefiningOp();
}
Operation* GetDefiningOpForInput(Operation* op, uint32_t index);
/**
* @brief Get an operation that is the first to use the specific output of the
* operation.
*
* @param Operation*
*
* @param Operation* pointer to an operation
* @param uint32_t index of result of the operation
* @return Operation*
*/
template <uint32_t Index = 0>
Operation* GetFirstUseOperationForOutput(Operation* op) {
PADDLE_ENFORCE_EQ(
Index < op->num_results(),
true,
phi::errors::InvalidArgument("Output op result's index must be valid."));
return op->result(Index).first_use().owner();
}
Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index);
} // namespace ir
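The template helpers above are replaced by runtime-index overloads; call sites change as in the pattern_rewrite_test.cc updates later in this diff:

// Before this PR: index passed as a template argument.
//   auto prev_op = ir::GetDefiningOpForInput<0>(op);
// After this PR: index passed as a runtime argument.
auto prev_op = ir::GetDefiningOpForInput(op, 0);
auto first_user = ir::GetFirstUseOperationForOutput(op, 0);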
......@@ -37,7 +37,7 @@ endfunction()
add_subdirectory(core)
add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
add_subdirectory(transforms)
add_subdirectory(builtin_transforms)
if(WIN32)
if(WITH_SHARED_IR)
......
......@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/transforms/dce.h"
#include <memory>
#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
namespace {
......@@ -22,30 +23,43 @@ namespace {
// TODO(wilber): After SideEffectTrait is supported, only ops with the
// NoSideEffectTrait can be removed by the DCE pass.
// Now just a naive implementation.
class DcePass : public ir::Pass {
class DeadCodeEliminationPass : public ir::Pass {
public:
DcePass() : ir::Pass("DcePass", 0) {}
DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
void Run(ir::Operation *op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "DcePass should run on module op.");
auto *block = module_op.block();
std::vector<ir::Operation> erased_op;
std::vector<ir::Operation *> erased_op;
for (auto it = block->begin(); it != block->end(); ++it) {
auto &op = *it;
// TODO(wilber): Support NoSideEffect trait.
// if (!(*it)->HasTrait<NoSideEffect>()) continue;
// if (!op->HasTrait<NoSideEffect>()) continue;
bool use_empty = true;
for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
use_empty &= (*it)->result(i).use_empty();
for (uint32_t i = 0; i < op->num_results(); ++i) {
use_empty &= op->result(i).use_empty();
}
// TODO(wilber): Support Terminator trait.
if (use_empty && (*it)->name() != "pd.fetch") {
erased_op.push_back(**it);
if (use_empty && op->name() != "pd.fetch") {
erased_op.push_back(op);
}
}
for (auto ep : erased_op) block->erase(ep);
for (auto *op : erased_op) {
if (op->dyn_cast<ir::GetParameterOp>()) {
// Delete parameter from program.
ir::GetParameterOp get_parameter_op =
op->dyn_cast<ir::GetParameterOp>();
get_parameter_op->GetParentProgram()->parameters().erase(
get_parameter_op->attributes()
.at(get_parameter_op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data());
}
block->erase(*op);
}
}
bool CanApplyOn(ir::Operation *op) const override {
......@@ -57,6 +71,8 @@ class DcePass : public ir::Pass {
namespace ir {
std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }
std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
return std::make_unique<DeadCodeEliminationPass>();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/ir/core/dll_decl.h"
namespace ir {
class Pass;
IR_API std::unique_ptr<Pass> CreateDeadCodeEliminationPass();
} // namespace ir
......@@ -55,8 +55,8 @@ struct PassInfo {
std::string name;
// opt_level=0: the basic pass which framework need.
// opt_level=1: the fusion logical pass.
// opt_level=2: constant fold, cse, memory optimize, etc.
// opt_level=1: constant fold, cse, memory optimize, etc.
// opt_level=2: the fusion logical pass.
// opt_level=3: layout, etc.
uint8_t opt_level;
......
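Under the reshuffled opt_level mapping, the new constant folding pass registers at level 1 (see constant_folding_pass.cc above):

// opt_level = 1 now covers constant folding / CSE / memory-optimize passes.
ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {}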
......@@ -110,7 +110,8 @@ class IR_API PassManager {
// TODO(liuyuanle): Add flags to control printing behavior.
};
void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);
void EnableIRPrinting(std::unique_ptr<IRPrinterOption> option =
std::make_unique<IRPrinterOption>());
void EnablePassTiming(bool print_module = true);
......
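With the default argument added here, IR printing can be turned on without constructing an option explicitly, as the updated test does:

ir::PassManager pm(ir::IrContext::Instance());
pm.EnableIRPrinting();   // uses a default-constructed IRPrinterOption
pm.EnablePassTiming();   // print_module defaults to true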
......@@ -26,6 +26,7 @@
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
......@@ -148,6 +149,8 @@ class IR_API Pattern {
const PatternBenefit benefit_;
IrContext* context_;
// A list of the potential operations that may be generated when rewriting an
// op with this pattern.
std::vector<OpInfo> generated_ops_;
std::string debug_name_;
......@@ -162,13 +165,13 @@ class IR_API RewritePattern : public Pattern {
virtual void Rewrite(Operation* op,
PatternRewriter& rewriter) const { // NOLINT
throw(
IR_THROW(
"need to implement either MatchAndRewrite or one of the rewrite "
"functions.");
}
virtual bool Match(Operation* op) const {
throw("need to implement either MatchAndRewrite or Match.");
IR_THROW("need to implement either MatchAndRewrite or Match.");
return false;
}
......@@ -220,10 +223,10 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
virtual void Rewrite(SourceOp op,
PatternRewriter& rewriter) const { // NOLINT
throw("must override Rewrite or MatchAndRewrite");
IR_THROW("must override Rewrite or MatchAndRewrite");
}
virtual bool Match(SourceOp op) const {
throw("must override Match or MatchAndRewrite");
IR_THROW("must override Match or MatchAndRewrite");
}
virtual bool MatchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const { // NOLINT
......
......@@ -15,6 +15,10 @@
#include <gtest/gtest.h>
#include "glog/logging.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
......@@ -125,7 +129,7 @@ class TestPass : public ir::Pass {
pass_state().preserved_analyses.Preserve<CountOpAnalysis>();
CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
true);
CHECK_EQ(count_op_analysis.count, 4);
CHECK_EQ(count_op_analysis.count, 11);
auto module_op = op->dyn_cast<ir::ModuleOp>();
CHECK_EQ(module_op.operation(), op);
......@@ -143,139 +147,78 @@ class TestPass : public ir::Pass {
}
};
TEST(pass_manager, PassManager) {
//
// TODO(liuyuanle): remove test code other than pass manager
//
// (1) Init environment.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *builtin_dialect =
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
builtin_dialect->RegisterOp<AddOp>();
ir::Dialect *paddle_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
// (2) Create an empty program object
ir::Program program(ctx);
// (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
phi::DDim dims = {2, 2};
phi::DataLayout data_layout = phi::DataLayout::NCHW;
phi::LoD lod = {{0, 1, 2}};
size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
std::vector<float> data_a = {1, 2, 3, 4};
std::unique_ptr<ir::Parameter> parameter_a =
std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_a.data()),
4 * sizeof(float),
dense_tensor_dtype);
program.SetParameter("a", std::move(parameter_a));
EXPECT_EQ(program.parameters_num() == 1, true);
std::vector<float> data_b = {5, 6, 7, 8};
std::unique_ptr<ir::Parameter> parameter_b =
std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_b.data()),
4 * sizeof(float),
dense_tensor_dtype);
program.SetParameter("b", std::move(parameter_b));
EXPECT_EQ(program.parameters_num() == 2, true);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
ir::Builder builder(ctx, program.block());
auto op1 = builder.Build<ir::GetParameterOp>("a", dense_tensor_dtype);
EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface =
op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> a_var =
a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
EXPECT_EQ(a_tensor.numel(), 4);
EXPECT_EQ(a_tensor.dims(), dims);
EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(a_tensor.layout(), data_layout);
EXPECT_EQ(a_tensor.lod(), lod);
EXPECT_EQ(a_tensor.offset(), offset);
for (int64_t i = 0; i < a_tensor.numel(); i++) {
EXPECT_EQ(*(a_tensor.data<float>() + i), data_a[i]);
}
void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_filter_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 3, 3, 3},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_mean_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp full_variance_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_scale_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::FullOp full_bias_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::Conv2dOp conv2d_op =
builder.Build<paddle::dialect::Conv2dOp>(full_input_op.out(),
full_filter_op.out());
paddle::dialect::BatchNormOp batch_norm_op =
builder.Build<paddle::dialect::BatchNormOp>(conv2d_op.out(),
full_mean_op.out(),
full_variance_op.out(),
full_scale_op.out(),
full_bias_op.out(),
true,
0.9,
1e-6,
"NCHW",
false,
false);
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
auto op2 = builder.Build<ir::GetParameterOp>("b", dense_tensor_dtype);
EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
Interface *b_interface =
op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> b_var =
b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
EXPECT_EQ(b_tensor.numel(), 4);
EXPECT_EQ(b_tensor.dims(), dims);
EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(b_tensor.layout(), data_layout);
EXPECT_EQ(b_tensor.lod(), lod);
EXPECT_EQ(b_tensor.offset(), offset);
for (int64_t i = 0; i < b_tensor.numel(); i++) {
EXPECT_EQ(*(b_tensor.data<float>() + i), data_b[i]);
}
auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
batch_norm_op.out(), std::vector<int>{0, 2, 3, 1});
// (6) Def c = AddOp(a, b), execute this op.
auto op3 =
builder.Build<AddOp>(op1->result(0), op2->result(0), dense_tensor_dtype);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace()));
phi::DenseTensor c_tensor =
phi::Add<float, phi::CPUContext>(*dev_ctx, a_tensor, b_tensor);
std::shared_ptr<paddle::framework::Variable> variable_c =
std::make_shared<paddle::framework::Variable>();
auto *dst_tensor = variable_c->GetMutable<phi::DenseTensor>();
*dst_tensor = c_tensor;
EXPECT_EQ(dst_tensor->numel(), b_tensor.numel());
EXPECT_EQ(dst_tensor->dims(), b_tensor.dims());
EXPECT_EQ(dst_tensor->dtype(), b_tensor.dtype());
EXPECT_EQ(dst_tensor->layout(), b_tensor.layout());
EXPECT_EQ(dst_tensor->lod(), b_tensor.lod());
EXPECT_EQ(dst_tensor->offset(), b_tensor.offset());
for (int64_t i = 0; i < dst_tensor->numel(); i++) {
EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
}
auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
// (7) Def SetParameterOp(c, "c")
auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
Interface *c_interface =
op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c =
c_interface->VariableToParameter(variable_c.get());
EXPECT_EQ(parameter_c->type(), dense_tensor_dtype);
for (int64_t i = 0; i < dst_tensor->numel(); i++) {
EXPECT_EQ(*(dst_tensor->data<float>() + i),
*(static_cast<float *>(parameter_c->data()) + i));
}
program.SetParameter("c", std::move(parameter_c));
builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out", 0);
}
// (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true);
TEST(pass_manager, PassManager) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder(ctx, program.block());
BuildProgram(builder);
//
// TODO(liuyuanle): remove the code above.
//
EXPECT_EQ(program.block()->size(), 11u);
// (9) Test pass manager for program.
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
// pm.EnableIRPrinting();
pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
[](ir::Pass *pass, ir::Operation *op) {
return pass->name() == "TestPass";
......
cc_test_old(
pattern_rewrite_test
SRCS
pattern_rewrite_test.cc
DEPS
ir
pd_dialect
transform_general_functions
gtest)
set(PATTERN_REWRITE_TEST_DEPS _constant_folding_pass
transform_general_functions gtest pd_dialect ir)
if(WITH_DISTRIBUTE)
set(PATTERN_REWRITE_TEST_DEPS ${PATTERN_REWRITE_TEST_DEPS} fleet_executor)
endif()
cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS
${PATTERN_REWRITE_TEST_DEPS})
......@@ -15,12 +15,15 @@
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <vector>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
......@@ -39,7 +42,7 @@
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/ir/transforms/dce.h"
#include "paddle/phi/core/kernel_registry.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
......@@ -56,6 +59,18 @@
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/infermeta/multiary.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(divide, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(reshape, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(fetch, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(conv2d, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(transpose, CPU, ALL_LAYOUT);
// Define op1.
class Operation1 : public ir::Op<Operation1> {
public:
......@@ -197,7 +212,7 @@ class RedundantTransposeFusePattern
bool MatchAndRewrite(paddle::dialect::TransposeOp op,
ir::PatternRewriter &rewriter) const override {
auto prev_op = ir::GetDefiningOpForInput<0>(op);
auto prev_op = ir::GetDefiningOpForInput(op, 0);
std::vector<int> axis_last = GetAxis(op);
auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
if (prev_trans_op) {
......@@ -207,7 +222,7 @@ class RedundantTransposeFusePattern
auto new_perm = GetPerm(axis_first, axis_last);
rewriter.SetInsertionPoint(op);
auto new_transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
ir::GetDefiningOpForInput<0>(prev_trans_op)->result(0), new_perm);
ir::GetDefiningOpForInput(prev_trans_op, 0)->result(0), new_perm);
rewriter.ReplaceOp(op, {new_transpose_op.out()});
return true;
}
......@@ -249,7 +264,7 @@ class Conv2dBnFusePattern
ir::PatternRewriter &rewriter) const override { // NOLINT
// The next op should be batch_norm.
paddle::dialect::Conv2dOp conv2d_op =
ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
ir::GetDefiningOpForInput(op, 0)->dyn_cast<paddle::dialect::Conv2dOp>();
if (!conv2d_op) return false;
ir::OpResult conv2d_out = conv2d_op.out();
......@@ -320,7 +335,6 @@ class Conv2dBnFusePattern
std::string data_format =
new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
new_bias_new_shape[0] = new_conv2d_out_shape[0];
new_bias_new_shape[1] = new_conv2d_out_shape[1];
paddle::dialect::ReshapeOp reshape_bias_op =
rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
......@@ -895,7 +909,7 @@ class Conv2dAddFusePattern
ir::PatternRewriter &rewriter) const override { // NOLINT
// The next op should be add.
paddle::dialect::Conv2dOp conv2d_op =
ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
ir::GetDefiningOpForInput(op, 0)->dyn_cast<paddle::dialect::Conv2dOp>();
if (!conv2d_op) return false;
ir::OpResult conv2d_out = conv2d_op.out();
......@@ -929,12 +943,10 @@ class Conv2dAddFusePattern
conv2d_attributes.at("dilations"),
conv2d_attributes.at("groups"),
conv2d_attributes.at("data_format"),
ir::StrAttribute::get(ir::IrContext::Instance(), "identity"),
ir::BoolAttribute::get(ir::IrContext::Instance(), true),
ir::ArrayAttribute::get(ir::IrContext::Instance(),
std::vector<ir::Attribute>()),
ir::Int32Attribute::get(ir::IrContext::Instance(), int32_t(0)),
};
rewriter.str_attr("identity"),
rewriter.bool_attr(true),
rewriter.array_attr(std::vector<ir::Attribute>{}),
rewriter.int32_attr(0)};
ir::AttributeMap conv2d_fusion_attributes;
for (size_t i = 0; i < conv2d_fusion_attrStr.size(); ++i) {
conv2d_fusion_attributes[conv2d_fusion_attrStr[i]] = con2d_fusing_attr[i];
......@@ -943,7 +955,7 @@ class Conv2dAddFusePattern
ir::OpResult tmpResidual;
auto conv2d_fuse_op = rewriter.Build<paddle::dialect::Conv2dFusionOpTest>(
ir::GetDefiningOpForInput<0>(conv2d_op)->result(0),
ir::GetDefiningOpForInput(conv2d_op, 0)->result(0),
conv2d_filter_result,
bias,
tmpResidual,
......@@ -956,27 +968,48 @@ class Conv2dAddFusePattern
class TestPass : public ir::Pass {
public:
TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
ir::RewritePatternSet ps(op->ir_context());
ps.Add<RedundantTransposeFusePattern>(op->ir_context());
ps.Add<Conv2dBnFusePattern>(op->ir_context());
ps.Add<Conv2dAddFusePattern>(op->ir_context());
ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
bool Initialize(ir::IrContext *context) override {
ir::RewritePatternSet ps(context);
ps.Add<RedundantTransposeFusePattern>(context);
auto conv_bn_pattern = std::make_unique<Conv2dBnFusePattern>(
context,
1,
std::vector<std::string>{paddle::dialect::FullOp::name(),
paddle::dialect::AddOp::name(),
paddle::dialect::SqrtOp::name(),
paddle::dialect::DivideOp::name(),
paddle::dialect::ReshapeOp::name(),
paddle::dialect::MultiplyOp::name(),
paddle::dialect::SubtractOp::name(),
paddle::dialect::Conv2dOp::name()});
LOG(INFO) << "Conv2dBnFusePattern will generate the following operations: ";
for (auto op_info : conv_bn_pattern->generated_ops()) {
LOG(INFO) << "--- " << op_info.name();
}
ps.Add(std::move(conv_bn_pattern));
patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
}
void Run(ir::Operation *op) override {
ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
ir::ApplyPatternsGreedily(op->region(0), frozen_ps, cfg);
ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}
bool CanApplyOn(ir::Operation *op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
private:
ir::FrozenRewritePatternSet patterns_;
};
void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 3, 16, 16},
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
......@@ -1045,11 +1078,19 @@ TEST(pattern_rewrite, Patterns) {
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDcePass());
program.Print(std::cout);
std::cout << std::endl;
pm.Run(&program);
LOG(INFO) << "After Pass.";
program.Print(std::cout);
std::cout << std::endl;
pm.AddPass(ir::CreateConstantFoldingPass());
pm.AddPass(ir::CreateDeadCodeEliminationPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();
// pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
// [](ir::Pass *pass, ir::Operation *op) {
// return pass->name() == "ConstantFoldingPass";
// },
// [](ir::Pass *pass, ir::Operation *op) {
// return pass->name() == "ConstantFoldingPass";
// },
// true,
// true));
CHECK_EQ(pm.Run(&program), true);
}
......@@ -69,7 +69,6 @@ TEST(StandaloneExecutor, run) {
auto place = platform::CPUPlace();
Scope scope;
ProgramDesc prog_desc;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
......@@ -141,8 +140,6 @@ TEST(StandaloneExecutor, run_2) {
auto place = platform::CPUPlace();
Scope scope;
ProgramDesc prog_desc;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
......@@ -216,8 +213,6 @@ TEST(StandaloneExecutor, data_transfer) {
auto place = platform::CPUPlace();
Scope scope;
ProgramDesc prog_desc;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
......