Unverified · Commit 4905a247 · authored by Yuanle Liu, committed by GitHub

[PASS] add constant folding pass (#55099)

Parent: df311526
@@ -139,14 +139,12 @@ void HandleForSpecialOp(ir::Operation* op,
   if (op_name == "pd.fetch") {
     // fetch is a very special op, with no output
     VLOG(6) << "Handle for pd.fetch:";
-    for (size_t i = 0; i < input_num; ++i) {
-      auto var = scope->Var("fetch");
-      VLOG(6) << "Create var: fetch in scope " << scope;
-      auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
-      int index =
-          op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
-      fetch_list->resize(index + 1);
-    }
+    auto var = scope->Var("fetch");
+    VLOG(6) << "Create var: fetch in scope " << scope;
+    auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
+    int index =
+        op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
+    fetch_list->resize(index + 1);
   }

   if (op_name == "pd.feed") {
......
@@ -7,3 +7,9 @@ cc_library(
   pd_op_to_kernel_pass
   SRCS pd_op_to_kernel_pass.cc
   DEPS phi_utils pd_interface pd_trait ir)
+
+cc_library(
+  _constant_folding_pass
+  SRCS constant_folding_pass.cc
+  DEPS standalone_executor phi pd_op_to_kernel_pass transform_general_functions
+       ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace {
class ConstantFoldingPattern : public ir::RewritePattern {
public:
ConstantFoldingPattern(ir::IrContext* context,
ir::PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) {
}
bool Match(ir::Operation* op) const override {
// TODO(liuyuanle): Use trait to improve robustness.
if (op->dyn_cast<ir::GetParameterOp>() ||
op->dyn_cast<ir::SetParameterOp>() ||
op->dyn_cast<paddle::dialect::FetchOp>())
return false;
// Inputs must come from get parameter op.
for (uint32_t i = 0; i < op->num_operands(); ++i)
if (ir::GetDefiningOpForInput(op, i)->dyn_cast<ir::GetParameterOp>() ==
nullptr)
return false;
return true;
}
void Rewrite(ir::Operation* op,
ir::PatternRewriter& rewriter) const override { // NOLINT
ir::Program* program = op->GetParentProgram();
auto temp_program = BuildProgramFromOperation(op);
// Execute program
paddle::framework::interpreter::ExecutionConfig exe_config;
exe_config.create_local_scope = false;
paddle::framework::InterpreterCore core(
phi::CPUPlace{},
paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
&scope_,
exe_config);
paddle::framework::FetchList fetch_list = core.Run({});
// TODO(liuyuanle): Support multiple output.
auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
std::unique_ptr<ir::Parameter> parameter = std::make_unique<ir::Parameter>(
reinterpret_cast<void*>(out_tensor.data()),
out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
op->result(0).type());
std::string param_name =
"@constant_folding_pass@_" + std::to_string(suffix_++);
auto* param_var = scope_.Var(param_name);
auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
*param_tensor = out_tensor;
program->SetParameter(param_name, std::move(parameter));
// rewriter.SetInsertionPoint(op);
auto get_parameter_op =
rewriter.Build<ir::GetParameterOp>(param_name, op->result(0).type());
rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op);
}
private:
std::unique_ptr<ir::Program> BuildProgramFromOperation(
ir::Operation* op) const {
auto program = std::make_unique<ir::Program>(ir_context());
ir::Builder builder = ir::Builder(ir_context(), program->block());
// prepare op inputs
std::vector<ir::OpResult> op_inputs;
for (uint32_t i = 0; i < op->num_operands(); i++) {
PADDLE_ENFORCE_EQ(
op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));
auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
program->SetParameter(param_name,
std::make_unique<ir::Parameter>(*param));
auto* param_var = scope_.FindVar(param_name);
PADDLE_ENFORCE_NOT_NULL(
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));
auto get_parameter_op =
builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
op_inputs.push_back(get_parameter_op->result(0));
}
// prepare op outputs
std::vector<ir::Type> output_types;
for (uint32_t i = 0; i < op->num_results(); i++) {
output_types.push_back(op->result(i).type());
}
auto* temp_op =
builder.Build(op_inputs, op->attributes(), output_types, op->info());
// TODO(liuyuanle): Support multiple output.
// for (uint32_t i = 0; i < op->num_results(); i++) {
PADDLE_ENFORCE_EQ(
temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's output must be a dense tensor type."));
builder.Build<paddle::dialect::FetchOp>(
temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
// }
return program;
}
private:
static size_t suffix_;
static paddle::framework::Scope scope_;
};

size_t ConstantFoldingPattern::suffix_ = 0;
paddle::framework::Scope ConstantFoldingPattern::scope_ = {};

class ConstantFoldingPass : public ir::Pass {
public:
// TODO(liuyuanle): Naming convention for pass.
ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {}
bool Initialize(ir::IrContext* context) override {
ir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(context);
patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
}
void Run(ir::Operation* op) override {
ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}
bool CanApplyOn(ir::Operation* op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
private:
ir::FrozenRewritePatternSet patterns_;
};

}  // namespace

namespace ir {

std::unique_ptr<Pass> CreateConstantFoldingPass() {
  return std::make_unique<ConstantFoldingPass>();
}

}  // namespace ir
@@ -18,8 +18,9 @@
 #include "paddle/ir/core/dll_decl.h"

 namespace ir {

 class Pass;

-IR_API std::unique_ptr<Pass> CreateDcePass();
+IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();

 }  // namespace ir
@@ -17,22 +17,28 @@
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_type.h"
 #include "paddle/ir/core/builtin_op.h"
+#include "paddle/ir/core/parameter.h"
 #include "paddle/ir/core/program.h"

 namespace ir {

-ir::Parameter* GetParameterFromValue(ir::Value value) {
+std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) {
   ir::GetParameterOp op = value.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
   PADDLE_ENFORCE_NOT_NULL(
       op,
       phi::errors::InvalidArgument(
           "Value must be a weight from a GetParameter op."));
   ir::Program* program = op->GetParentProgram();
+  PADDLE_ENFORCE_NOT_NULL(
+      program, phi::errors::InvalidArgument("Program should not be null."));
   std::string name = op->attributes()
                          .at(op.attributes_name[0])
                          .dyn_cast<ir::StrAttribute>()
                          .data();
-  return program->GetParameter(name);
+  ir::Parameter* param = program->GetParameter(name);
+  PADDLE_ENFORCE_NOT_NULL(
+      param, phi::errors::InvalidArgument("Parameter should not be null."));
+  return {name, param};
 }

 const phi::DDim& GetShapeFromValue(ir::Value value) {
@@ -44,4 +50,29 @@ const phi::DDim& GetShapeFromValue(ir::Value value) {
   return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
 }

+ir::Type GetDataTypeFromValue(ir::Value value) {
+  // TODO(dev): Support other types like DenseTensor.
+  PADDLE_ENFORCE_EQ(
+      value.type().isa<paddle::dialect::DenseTensorType>(),
+      true,
+      phi::errors::InvalidArgument("Value's type must be a DenseTensorType."));
+  return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
+}
+
+Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
+  PADDLE_ENFORCE_EQ(
+      index < op->num_operands(),
+      true,
+      phi::errors::InvalidArgument("Input operand's index must be valid."));
+  return op->operand(index).GetDefiningOp();
+}
+
+Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
+  PADDLE_ENFORCE_EQ(
+      index < op->num_results(),
+      true,
+      phi::errors::InvalidArgument("Output op result's index must be valid."));
+  return op->result(index).first_use().owner();
+}
+
 }  // namespace ir
@@ -16,6 +16,7 @@
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/parameter.h"
+#include "paddle/ir/core/type.h"
 #include "paddle/ir/core/value.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/enforce.h"
@@ -24,15 +25,16 @@
 namespace ir {

 /**
- * @brief Get the parameter from a value.
+ * @brief Get the [name, parameter] pair of the parameter from a value.
  *
  * @note The value must be an output of a GetParameterOp.
  *
  * @param ir::Value
  *
- * @return ir::Parameter*
+ * @return std::pair<std::string, ir::Parameter*>
  */
-ir::Parameter* GetParameterFromValue(ir::Value value);
+std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value);

 /**
  * @brief Get tensor's shape from a value.
@@ -43,37 +45,34 @@ ir::Parameter* GetParameterFromValue(ir::Value value);
  */
 const phi::DDim& GetShapeFromValue(ir::Value value);

+/**
+ * @brief Get tensor's data type from a value.
+ *
+ * @param ir::Value
+ *
+ * @return ir::Type
+ */
+ir::Type GetDataTypeFromValue(ir::Value value);
+
 /**
  * @brief Get an operation that defines the specific input of the operation.
  *
- * @param Operation*
+ * @param Operation* pointer to an operation
+ * @param uint32_t index of operand of the operation
  *
  * @return Operation*
  */
-template <uint32_t Index = 0>
-Operation* GetDefiningOpForInput(Operation* op) {
-  PADDLE_ENFORCE_EQ(
-      Index < op->num_operands(),
-      true,
-      phi::errors::InvalidArgument("Intput operand's index must be valid."));
-  return op->operand(Index).GetDefiningOp();
-}
+Operation* GetDefiningOpForInput(Operation* op, uint32_t index);

 /**
  * @brief Get an operation that is the first to use the specific output of the
  * operation.
  *
- * @param Operation*
- *
+ * @param Operation* pointer to an operation
+ * @param uint32_t index of result of the operation
  * @return Operation*
  */
-template <uint32_t Index = 0>
-Operation* GetFirstUseOperationForOutput(Operation* op) {
-  PADDLE_ENFORCE_EQ(
-      Index < op->num_results(),
-      true,
-      phi::errors::InvalidArgument("Output op result's index must be valid."));
-  return op->result(Index).first_use().owner();
-}
+Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index);

 }  // namespace ir
@@ -37,7 +37,7 @@ endfunction()

 add_subdirectory(core)
 add_subdirectory(pass)
 add_subdirectory(pattern_rewrite)
-add_subdirectory(transforms)
+add_subdirectory(builtin_transforms)

 if(WIN32)
   if(WITH_SHARED_IR)
......
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/ir/transforms/dce.h"
+#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
+#include <memory>

 #include "paddle/ir/core/builtin_op.h"
+#include "paddle/ir/core/program.h"
 #include "paddle/ir/pass/pass.h"

 namespace {
@@ -22,30 +23,43 @@ namespace {

 // TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
 // removed by dce pass.
 // Now just a naive implementation.
-class DcePass : public ir::Pass {
+class DeadCodeEliminationPass : public ir::Pass {
  public:
-  DcePass() : ir::Pass("DcePass", 0) {}
+  DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}

   void Run(ir::Operation *op) override {
     auto module_op = op->dyn_cast<ir::ModuleOp>();
     IR_ENFORCE(module_op, "DcePass should run on module op.");
     auto *block = module_op.block();
-    std::vector<ir::Operation> erased_op;
+    std::vector<ir::Operation *> erased_op;
     for (auto it = block->begin(); it != block->end(); ++it) {
+      auto &op = *it;
       // TODO(wilber): Support NoSideEffect trait.
-      // if (!(*it)->HasTrait<NoSideEffect>()) continue;
+      // if (!op->HasTrait<NoSideEffect>()) continue;

       bool use_empty = true;
-      for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
-        use_empty &= (*it)->result(i).use_empty();
+      for (uint32_t i = 0; i < op->num_results(); ++i) {
+        use_empty &= op->result(i).use_empty();
       }
       // TODO(wilber): Support Terminator trait.
-      if (use_empty && (*it)->name() != "pd.fetch") {
-        erased_op.push_back(**it);
+      if (use_empty && op->name() != "pd.fetch") {
+        erased_op.push_back(op);
       }
     }
-    for (auto ep : erased_op) block->erase(ep);
+
+    for (auto *op : erased_op) {
+      if (op->dyn_cast<ir::GetParameterOp>()) {
+        // Delete parameter from program.
+        ir::GetParameterOp get_parameter_op =
+            op->dyn_cast<ir::GetParameterOp>();
+        get_parameter_op->GetParentProgram()->parameters().erase(
+            get_parameter_op->attributes()
+                .at(get_parameter_op.attributes_name[0])
+                .dyn_cast<ir::StrAttribute>()
+                .data());
+      }
+      block->erase(*op);
+    }
   }

   bool CanApplyOn(ir::Operation *op) const override {
@@ -57,6 +71,8 @@ class DcePass : public ir::Pass {

 namespace ir {

-std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }
+std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
+  return std::make_unique<DeadCodeEliminationPass>();
+}

 }  // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "paddle/ir/core/dll_decl.h"

namespace ir {

class Pass;

IR_API std::unique_ptr<Pass> CreateDeadCodeEliminationPass();

}  // namespace ir
@@ -55,8 +55,8 @@ struct PassInfo {
   std::string name;

   // opt_level=0: the basic pass which framework need.
-  // opt_level=1: the fusion logical pass.
-  // opt_level=2: constant fold, cse, memory optimize, etc.
+  // opt_level=1: constant fold, cse, memory optimize, etc.
+  // opt_level=2: the fusion logical pass.
   // opt_level=3: layout, etc.
   uint8_t opt_level;
......
@@ -110,7 +110,8 @@ class IR_API PassManager {
     // TODO(liuyuanle): Add flags to control printing behavior.
   };

-  void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);
+  void EnableIRPrinting(std::unique_ptr<IRPrinterOption> option =
+                            std::make_unique<IRPrinterOption>());

   void EnablePassTiming(bool print_module = true);
......
@@ -26,6 +26,7 @@
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/dll_decl.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/op_info.h"
 #include "paddle/ir/core/operation.h"
@@ -148,6 +149,8 @@ class IR_API Pattern {
   const PatternBenefit benefit_;
   IrContext* context_;
+  // A list of the potential operations that may be generated when rewriting
+  // an op with this pattern.
   std::vector<OpInfo> generated_ops_;

   std::string debug_name_;
@@ -162,13 +165,13 @@ class IR_API RewritePattern : public Pattern {
   virtual void Rewrite(Operation* op,
                        PatternRewriter& rewriter) const {  // NOLINT
-    throw(
+    IR_THROW(
         "need to implement either MatchAndRewrite or one of the rewrite "
         "functions.");
   }

   virtual bool Match(Operation* op) const {
-    throw("need to implement either MatchAndRewrite or Match.");
+    IR_THROW("need to implement either MatchAndRewrite or Match.");
     return false;
   }
@@ -220,10 +223,10 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
   virtual void Rewrite(SourceOp op,
                        PatternRewriter& rewriter) const {  // NOLINT
-    throw("must override Rewrite or MatchAndRewrite");
+    IR_THROW("must override Rewrite or MatchAndRewrite");
   }

   virtual bool Match(SourceOp op) const {
-    throw("must override Match or MatchAndRewrite");
+    IR_THROW("must override Match or MatchAndRewrite");
   }

   virtual bool MatchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const {  // NOLINT
......
@@ -15,6 +15,10 @@
 #include <gtest/gtest.h>

 #include "glog/logging.h"

+// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
+// paddle/fluid/ir/dialect/CMakeLists.txt.
+#include "paddle/fluid/ir/dialect/pd_op.h"
+
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_type.h"
 #include "paddle/fluid/ir/dialect/utils.h"
@@ -125,7 +129,7 @@ class TestPass : public ir::Pass {
     pass_state().preserved_analyses.Preserve<CountOpAnalysis>();
     CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
              true);
-    CHECK_EQ(count_op_analysis.count, 4);
+    CHECK_EQ(count_op_analysis.count, 11);

     auto module_op = op->dyn_cast<ir::ModuleOp>();
     CHECK_EQ(module_op.operation(), op);
@@ -143,139 +147,78 @@ class TestPass : public ir::Pass {
   }
 };

-TEST(pass_manager, PassManager) {
-  //
-  // TODO(liuyuanle): remove test code other than pass manager
-  //
-
-  // (1) Init environment.
+void BuildProgram(ir::Builder &builder) {  // NOLINT
+  paddle::dialect::FullOp full_input_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+  paddle::dialect::FullOp full_filter_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 3, 3, 3},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+  paddle::dialect::FullOp full_mean_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+  paddle::dialect::FullOp full_variance_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+  paddle::dialect::FullOp full_scale_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+  paddle::dialect::FullOp full_bias_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+  paddle::dialect::Conv2dOp conv2d_op =
+      builder.Build<paddle::dialect::Conv2dOp>(full_input_op.out(),
+                                               full_filter_op.out());
+  paddle::dialect::BatchNormOp batch_norm_op =
+      builder.Build<paddle::dialect::BatchNormOp>(conv2d_op.out(),
+                                                  full_mean_op.out(),
+                                                  full_variance_op.out(),
+                                                  full_scale_op.out(),
+                                                  full_bias_op.out(),
+                                                  true,
+                                                  0.9,
+                                                  1e-6,
+                                                  "NCHW",
+                                                  false,
+                                                  false);
+  auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
+      batch_norm_op.out(), std::vector<int>{0, 2, 3, 1});
+  auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
+      transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
+  builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out", 0);
+}
+
+TEST(pass_manager, PassManager) {
   ir::IrContext *ctx = ir::IrContext::Instance();
-  ir::Dialect *builtin_dialect =
-      ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
-  builtin_dialect->RegisterOp<AddOp>();
-  ir::Dialect *paddle_dialect =
-      ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
-
-  // (2) Create an empty program object
+  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+
   ir::Program program(ctx);
-
-  // (3) Create a float32 DenseTensor Parameter and save into Program
-  ir::Type fp32_dtype = ir::Float32Type::get(ctx);
-  phi::DDim dims = {2, 2};
-  phi::DataLayout data_layout = phi::DataLayout::NCHW;
-  phi::LoD lod = {{0, 1, 2}};
-  size_t offset = 0;
-  ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
-      ctx, fp32_dtype, dims, data_layout, lod, offset);
-  std::vector<float> data_a = {1, 2, 3, 4};
-  std::unique_ptr<ir::Parameter> parameter_a =
-      std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_a.data()),
-                                      4 * sizeof(float),
-                                      dense_tensor_dtype);
-  program.SetParameter("a", std::move(parameter_a));
-  EXPECT_EQ(program.parameters_num() == 1, true);
-
-  std::vector<float> data_b = {5, 6, 7, 8};
-  std::unique_ptr<ir::Parameter> parameter_b =
-      std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_b.data()),
-                                      4 * sizeof(float),
-                                      dense_tensor_dtype);
-  program.SetParameter("b", std::move(parameter_b));
-  EXPECT_EQ(program.parameters_num() == 2, true);
-
-  // (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
-  ir::Builder builder(ctx, program.block());
-  auto op1 = builder.Build<ir::GetParameterOp>("a", dense_tensor_dtype);
-  EXPECT_EQ(&program, op1->GetParentProgram());
-  EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
-  using Interface = paddle::dialect::ParameterConvertInterface;
-  Interface *a_interface =
-      op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
-  std::shared_ptr<paddle::framework::Variable> a_var =
-      a_interface->ParameterToVariable(program.GetParameter("a"));
-  const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
-  EXPECT_EQ(a_tensor.numel(), 4);
-  EXPECT_EQ(a_tensor.dims(), dims);
-  EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
-  EXPECT_EQ(a_tensor.layout(), data_layout);
-  EXPECT_EQ(a_tensor.lod(), lod);
-  EXPECT_EQ(a_tensor.offset(), offset);
-  for (int64_t i = 0; i < a_tensor.numel(); i++) {
-    EXPECT_EQ(*(a_tensor.data<float>() + i), data_a[i]);
-  }
-
-  // (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
-  auto op2 = builder.Build<ir::GetParameterOp>("b", dense_tensor_dtype);
-  EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
-  Interface *b_interface =
-      op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
-  std::shared_ptr<paddle::framework::Variable> b_var =
-      b_interface->ParameterToVariable(program.GetParameter("b"));
-  const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
-  EXPECT_EQ(b_tensor.numel(), 4);
-  EXPECT_EQ(b_tensor.dims(), dims);
-  EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
-  EXPECT_EQ(b_tensor.layout(), data_layout);
-  EXPECT_EQ(b_tensor.lod(), lod);
-  EXPECT_EQ(b_tensor.offset(), offset);
-  for (int64_t i = 0; i < b_tensor.numel(); i++) {
-    EXPECT_EQ(*(b_tensor.data<float>() + i), data_b[i]);
-  }
-
-  // (6) Def c = AddOp(a, b), execute this op.
-  auto op3 =
-      builder.Build<AddOp>(op1->result(0), op2->result(0), dense_tensor_dtype);
-  phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
-      paddle::platform::DeviceContextPool::Instance().Get(
-          paddle::platform::CPUPlace()));
-  phi::DenseTensor c_tensor =
-      phi::Add<float, phi::CPUContext>(*dev_ctx, a_tensor, b_tensor);
-  std::shared_ptr<paddle::framework::Variable> variable_c =
-      std::make_shared<paddle::framework::Variable>();
-  auto *dst_tensor = variable_c->GetMutable<phi::DenseTensor>();
-  *dst_tensor = c_tensor;
-  EXPECT_EQ(dst_tensor->numel(), b_tensor.numel());
-  EXPECT_EQ(dst_tensor->dims(), b_tensor.dims());
-  EXPECT_EQ(dst_tensor->dtype(), b_tensor.dtype());
-  EXPECT_EQ(dst_tensor->layout(), b_tensor.layout());
-  EXPECT_EQ(dst_tensor->lod(), b_tensor.lod());
-  EXPECT_EQ(dst_tensor->offset(), b_tensor.offset());
-  for (int64_t i = 0; i < dst_tensor->numel(); i++) {
-    EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
-  }
-
-  // (7) Def SetParameterOp(c, "c")
-  auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
-  EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
-  Interface *c_interface =
-      op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
-  // ir::Parameter *parameter_c =
-  //     c_interface->VariableToParameter(variable_c.get());
-  std::unique_ptr<ir::Parameter> parameter_c =
-      c_interface->VariableToParameter(variable_c.get());
-  EXPECT_EQ(parameter_c->type(), dense_tensor_dtype);
-  for (int64_t i = 0; i < dst_tensor->numel(); i++) {
-    EXPECT_EQ(*(dst_tensor->data<float>() + i),
-              *(static_cast<float *>(parameter_c->data()) + i));
-  }
-  program.SetParameter("c", std::move(parameter_c));
-
-  // (8) Traverse Program
-  EXPECT_EQ(program.block()->size() == 4, true);
-  EXPECT_EQ(program.parameters_num() == 3, true);
-
-  //
-  // TODO(liuyuanle): remove the code above.
-  //
+  ir::Builder builder = ir::Builder(ctx, program.block());
+  BuildProgram(builder);
+
+  EXPECT_EQ(program.block()->size(), 11u);

   // (9) Test pass manager for program.
   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
+  // pm.EnableIRPrinting();
   pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
       [](ir::Pass *pass, ir::Operation *op) {
         return pass->name() == "TestPass";
......
-cc_test_old(
-  pattern_rewrite_test
-  SRCS
-  pattern_rewrite_test.cc
-  DEPS
-  ir
-  pd_dialect
-  transform_general_functions
-  gtest)
+set(PATTERN_REWRITE_TEST_DEPS _constant_folding_pass
+                              transform_general_functions gtest pd_dialect ir)
+
+if(WITH_DISTRIBUTE)
+  set(PATTERN_REWRITE_TEST_DEPS ${PATTERN_REWRITE_TEST_DEPS} fleet_executor)
+endif()
+
+cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS
+            ${PATTERN_REWRITE_TEST_DEPS})
@@ -15,12 +15,15 @@
 #include <gtest/gtest.h>

 #include <cstdint>
 #include <iostream>
+#include <memory>
 #include <numeric>
 #include <sstream>
 #include <vector>

 #include "paddle/fluid/ir/dialect/pd_attribute.h"
+#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
 #include "paddle/fluid/ir/transforms/transform_general_functions.h"
+#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/ir/core/builtin_dialect.h"
@@ -39,7 +42,7 @@
 #include "paddle/ir/pattern_rewrite/pattern_applicator.h"
 #include "paddle/ir/pattern_rewrite/pattern_match.h"
 #include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
-#include "paddle/ir/transforms/dce.h"
+#include "paddle/phi/core/kernel_registry.h"

 // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
 // paddle/fluid/ir/dialect/CMakeLists.txt.
@@ -56,6 +59,18 @@
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/infermeta/multiary.h"

+PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(divide, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(reshape, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(fetch, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(conv2d, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(transpose, CPU, ALL_LAYOUT);
+
 // Define op1.
 class Operation1 : public ir::Op<Operation1> {
  public:
@@ -197,7 +212,7 @@ class RedundantTransposeFusePattern
   bool MatchAndRewrite(paddle::dialect::TransposeOp op,
                        ir::PatternRewriter &rewriter) const override {
-    auto prev_op = ir::GetDefiningOpForInput<0>(op);
+    auto prev_op = ir::GetDefiningOpForInput(op, 0);
     std::vector<int> axis_last = GetAxis(op);
     auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
     if (prev_trans_op) {
@@ -207,7 +222,7 @@ class RedundantTransposeFusePattern
       auto new_perm = GetPerm(axis_first, axis_last);
       rewriter.SetInsertionPoint(op);
       auto new_transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
-          ir::GetDefiningOpForInput<0>(prev_trans_op)->result(0), new_perm);
+          ir::GetDefiningOpForInput(prev_trans_op, 0)->result(0), new_perm);
       rewriter.ReplaceOp(op, {new_transpose_op.out()});
       return true;
     }
@@ -249,7 +264,7 @@ class Conv2dBnFusePattern
                       ir::PatternRewriter &rewriter) const override {  // NOLINT
     // The next op should be batch_norm.
     paddle::dialect::Conv2dOp conv2d_op =
-        ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
+        ir::GetDefiningOpForInput(op, 0)->dyn_cast<paddle::dialect::Conv2dOp>();
     if (!conv2d_op) return false;

     ir::OpResult conv2d_out = conv2d_op.out();
@@ -320,7 +335,6 @@ class Conv2dBnFusePattern
     std::string data_format =
         new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
     IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
-    new_bias_new_shape[0] = new_conv2d_out_shape[0];
     new_bias_new_shape[1] = new_conv2d_out_shape[1];
     paddle::dialect::ReshapeOp reshape_bias_op =
         rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
@@ -895,7 +909,7 @@ class Conv2dAddFusePattern
                       ir::PatternRewriter &rewriter) const override {  // NOLINT
     // The next op should be add.
     paddle::dialect::Conv2dOp conv2d_op =
-        ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
+        ir::GetDefiningOpForInput(op, 0)->dyn_cast<paddle::dialect::Conv2dOp>();
     if (!conv2d_op) return false;

     ir::OpResult conv2d_out = conv2d_op.out();
@@ -929,12 +943,10 @@ class Conv2dAddFusePattern
         conv2d_attributes.at("dilations"),
         conv2d_attributes.at("groups"),
         conv2d_attributes.at("data_format"),
-        ir::StrAttribute::get(ir::IrContext::Instance(), "identity"),
-        ir::BoolAttribute::get(ir::IrContext::Instance(), true),
-        ir::ArrayAttribute::get(ir::IrContext::Instance(),
-                                std::vector<ir::Attribute>()),
-        ir::Int32Attribute::get(ir::IrContext::Instance(), int32_t(0)),
-    };
+        rewriter.str_attr("identity"),
+        rewriter.bool_attr(true),
+        rewriter.array_attr(std::vector<ir::Attribute>{}),
+        rewriter.int32_attr(0)};

     ir::AttributeMap conv2d_fusion_attributes;
     for (size_t i = 0; i < conv2d_fusion_attrStr.size(); ++i) {
       conv2d_fusion_attributes[conv2d_fusion_attrStr[i]] = con2d_fusing_attr[i];
@@ -943,7 +955,7 @@ class Conv2dAddFusePattern
     ir::OpResult tmpResidual;
     auto conv2d_fuse_op = rewriter.Build<paddle::dialect::Conv2dFusionOpTest>(
-        ir::GetDefiningOpForInput<0>(conv2d_op)->result(0),
+        ir::GetDefiningOpForInput(conv2d_op, 0)->result(0),
         conv2d_filter_result,
         bias,
         tmpResidual,
......
class TestPass : public ir::Pass { class TestPass : public ir::Pass {
public: public:
TestPass() : ir::Pass("TestPass", 1) {} TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
ir::RewritePatternSet ps(op->ir_context());
ps.Add<RedundantTransposeFusePattern>(op->ir_context());
ps.Add<Conv2dBnFusePattern>(op->ir_context());
ps.Add<Conv2dAddFusePattern>(op->ir_context());
ir::FrozenRewritePatternSet frozen_ps(std::move(ps)); bool Initialize(ir::IrContext *context) override {
ir::RewritePatternSet ps(context);
ps.Add<RedundantTransposeFusePattern>(context);
auto conv_bn_pattern = std::make_unique<Conv2dBnFusePattern>(
context,
1,
std::vector<std::string>{paddle::dialect::FullOp::name(),
paddle::dialect::AddOp::name(),
paddle::dialect::SqrtOp::name(),
paddle::dialect::DivideOp::name(),
paddle::dialect::ReshapeOp::name(),
paddle::dialect::MultiplyOp::name(),
paddle::dialect::SubtractOp::name(),
paddle::dialect::Conv2dOp::name()});
LOG(INFO) << "Conv2dBnFusePattern will generate the following operations: ";
for (auto op_info : conv_bn_pattern->generated_ops()) {
LOG(INFO) << "--- " << op_info.name();
}
ps.Add(std::move(conv_bn_pattern));
patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
}
void Run(ir::Operation *op) override {
ir::GreedyRewriteConfig cfg; ir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true; cfg.use_top_down_traversal = true;
cfg.max_iterations = 10; cfg.max_iterations = 10;
ir::ApplyPatternsGreedily(op->region(0), frozen_ps, cfg); ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
} }
bool CanApplyOn(ir::Operation *op) const override { bool CanApplyOn(ir::Operation *op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0; return op->name() == "builtin.module" && op->num_regions() > 0;
} }
private:
ir::FrozenRewritePatternSet patterns_;
}; };
void BuildProgram(ir::Builder &builder) { // NOLINT void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_input_op = paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 3, 16, 16}, builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
1.5, 1.5,
phi::DataType::FLOAT32, phi::DataType::FLOAT32,
phi::CPUPlace()); phi::CPUPlace());
@@ -1045,11 +1078,19 @@ TEST(pattern_rewrite, Patterns) {
   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
-  pm.AddPass(ir::CreateDcePass());
-  program.Print(std::cout);
-  std::cout << std::endl;
-  pm.Run(&program);
-  LOG(INFO) << "After Pass.";
-  program.Print(std::cout);
-  std::cout << std::endl;
+  pm.AddPass(ir::CreateConstantFoldingPass());
+  pm.AddPass(ir::CreateDeadCodeEliminationPass());
+  pm.EnablePassTiming();
+  pm.EnableIRPrinting();
+  // pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
+  //     [](ir::Pass *pass, ir::Operation *op) {
+  //       return pass->name() == "ConstantFoldingPass";
+  //     },
+  //     [](ir::Pass *pass, ir::Operation *op) {
+  //       return pass->name() == "ConstantFoldingPass";
+  //     },
+  //     true,
+  //     true));
+  CHECK_EQ(pm.Run(&program), true);
 }
@@ -69,7 +69,6 @@ TEST(StandaloneExecutor, run) {
   auto place = platform::CPUPlace();
   Scope scope;
-  ProgramDesc prog_desc;
   InterpreterCore test_core(place, std::move(kernel_program), &scope);

   test_core.Run({});
@@ -141,8 +140,6 @@ TEST(StandaloneExecutor, run_2) {
   auto place = platform::CPUPlace();
   Scope scope;
-  ProgramDesc prog_desc;
-
   InterpreterCore test_core(place, std::move(kernel_program), &scope);

   test_core.Run({});
@@ -216,8 +213,6 @@ TEST(StandaloneExecutor, data_transfer) {
   auto place = platform::CPUPlace();
   Scope scope;
-  ProgramDesc prog_desc;
-
   InterpreterCore test_core(place, std::move(kernel_program), &scope);

   test_core.Run({});
......