Unverified commit 0dc6efaa authored by: Aurelius84, committed by: GitHub

[NewIR]Support Build(GroupPtr) Logic in NewIRCompiler and Add UT (#56960)

* [NewIR]Support Build(GroupOps) in NewIRCompiler and Add UT

* fix unittest
Parent 7daffbf8
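
At a glance, the change lets callers hand the compiler an explicit op grouping instead of the previous one-op-per-group path. A minimal usage sketch, pieced together from the unit test added below (BuildProgram and the tuple it returns are test helpers introduced in this PR, not a public API):

    // Build a test ::ir::Program together with a hand-written grouping of its ops.
    auto prog_info = BuildProgram();  // helper defined in the new UT below
    std::shared_ptr<::ir::Program> program = std::get<0>(prog_info);
    std::vector<cinn::hlir::framework::newir::GroupPtr> groups = std::get<1>(prog_info);

    // Compile each group into a kernel and execute the resulting runtime program.
    auto target = cinn::common::DefaultNVGPUTarget();
    auto scope = cinn::hlir::framework::BuildScope(target, *program);
    cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope);
    auto runtime_program = ir_compiler.Build(groups);  // new Build(GroupPtr) overload
    runtime_program->Execute();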
@@ -16,6 +16,7 @@
 #include <string>
 #include <vector>
+#include "paddle/cinn/hlir/framework/new_ir/utils.h"
 #include "paddle/cinn/hlir/framework/op.h"
 #include "paddle/ir/core/operation.h"
@@ -30,20 +31,26 @@ struct Group {
  public:
   explicit Group(const std::vector<::ir::Operation*>& group_ops)
       : ops(group_ops) {
-    op_pattern_kind = OpPatternKind::kElementWise;
-    fn_name = "fn_";
-    for (auto& op : group_ops) {
-      fn_name += "_" + op->name();
-    }
+    Initialize();
+  }
+
+  explicit Group(std::initializer_list<::ir::Operation*> group_ops)
+      : ops(group_ops) {
+    Initialize();
   }
+
+  int group_id;
+  std::string fn_name;
+  OpPatternKind op_pattern_kind;
   std::vector<::ir::Operation*> ops;
   std::vector<std::string> input_names;
   std::vector<std::string> output_names;
-  int group_id;
-  // FIXME(Aurelius84): This should be refactored with CinnGroupOp
-  OpPatternKind op_pattern_kind;
-  std::string fn_name;
+
+ private:
+  void Initialize() {
+    op_pattern_kind = OpPatternKind::kElementWise;
+    fn_name = CompatibleInfo::GroupOpsName(ops);
+  }
 };
 }  // namespace newir
......
@@ -43,7 +43,7 @@ ir::Tensor GetTensor(const ::ir::Value& value) {
   auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
   auto in_shape = phi::vectorize<int>(type_info.dims());
   auto dtype = type_info.dtype();
-  std::string input_id = CompatibleInfo::InputName(value);
+  std::string input_id = CompatibleInfo::ValueName(value);
   return lang::CreatePlaceHolder(
       in_shape, utils::ConvertIRType(dtype), input_id);
 }
@@ -56,15 +56,16 @@ std::vector<ir::Tensor> CollectInputTensor(
   for (auto& operand : op->operands()) {
     CHECK(operand);
     auto in_value = operand.source();
-    ir::Tensor tensor;
+    VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value);
+    // NOTE(Aurelius84): Need always to create placeholder for input tensor.
+    ir::Tensor tensor = details::GetTensor(in_value);
     if (!tensor_map->count(in_value)) {
-      tensor = details::GetTensor(in_value);
       // record tensor.
       (*tensor_map)[in_value] = tensor;
       // record func input args
-      if (func_args != nullptr) func_args->push_back(tensor);
-    } else {
-      tensor = tensor_map->at(in_value);
+      if (func_args != nullptr) {
+        func_args->push_back(tensor);
+      }
     }
     tensors.push_back(tensor);
   }
@@ -76,7 +77,7 @@ void CollectOutputInfo(const ::ir::Operation* op,
                        std::vector<std::vector<int>>* out_shapes) {
   auto op_results = op->results();
   for (auto& out_value : op_results) {
-    std::string output_id = CompatibleInfo::OutputName(out_value);
+    std::string output_id = CompatibleInfo::ValueName(out_value);
     // group->output_names.push_back(output_id);
     auto type_info =
         out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
@@ -265,11 +266,11 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
       // output arg tensors
       group_func_arg_tensors->push_back(tensor);
       // output args
+      group->output_names.push_back(tensor->name);
       group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
       arg_name_set.insert(tensor->buffer->name);
     }
   }
   if (!done_op_schedule) {
     std::unordered_set<std::string> args_set;
     for (auto arg : group_func_args) {
@@ -329,6 +330,8 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(
     std::vector<ir::Tensor> op_func_arg_tensors =
         details::CollectInputTensor(op, group_func_arg_tensors, tensor_map);
+    VLOG(4) << "input size:" << op_func_arg_tensors.size();
     std::string cinn_op_name = CompatibleInfo::OpName(*op);
     const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name);
     auto op_impl = OpStrategy::SelectImpl(strategy[cinn_op](
@@ -348,6 +351,9 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(
     }
   }
+  VLOG(4) << "group_func_arg_tensors.size(): "
+          << group_func_arg_tensors->size();
   return func_bodies;
 }
@@ -364,7 +370,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
   // set tensor name = operand hash name
   auto op_results = op->results();
   for (const auto& result : op_results) {
-    std::string output_id = CompatibleInfo::OutputName(result);
+    std::string output_id = CompatibleInfo::ValueName(result);
     cinn_inputs.push_back(common::CINNValue(output_id));
   }
@@ -400,6 +406,8 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
     }
   }
+  VLOG(4) << "op_func_arg_tensors.size(): " << op_func_arg_tensors->size();
   // 2.Do lower
   std::string lower_fn_name = CompatibleInfo::OpFuncName(*op);
   std::vector<ir::LoweredFunc> funcs = lang::LowerVec(lower_fn_name,
......
@@ -36,13 +36,8 @@ std::string CompatibleInfo::OpName(const ::ir::Operation& op) {
   return cinn_op_name;
 }
-std::string CompatibleInfo::InputName(const ::ir::Value& value) {
-  return CompatibleInfo::kInputPrefix +
-         std::to_string(std::hash<::ir::Value>()(value));
-}
-
-std::string CompatibleInfo::OutputName(const ::ir::Value& value) {
-  return CompatibleInfo::kOutputPrefix +
+std::string CompatibleInfo::ValueName(const ::ir::Value& value) {
+  return CompatibleInfo::kNamePrefix +
          std::to_string(std::hash<::ir::Value>()(value));
 }
@@ -55,10 +50,10 @@ std::string CompatibleInfo::OpFuncName(const ::ir::Operation& op) {
 std::string CompatibleInfo::GroupOpsName(
     const std::vector<::ir::Operation*>& ops) {
-  std::string name = "fn_";
+  std::string name = "fn";
   for (auto* op : ops) {
     std::string op_name = OpName(*op);
-    name += cinn::common::Context::Global().NewName(op_name);
+    name += "_" + cinn::common::Context::Global().NewName(op_name);
   }
   return name;
 }
@@ -69,7 +64,7 @@ std::vector<std::string> CompatibleInfo::InputNames(const ::ir::Operation& op,
   std::unordered_set<std::string> repeat;
   for (int i = 0; i < op.num_operands(); ++i) {
     auto value = op.operand_source(i);
-    std::string name = CompatibleInfo::InputName(value);
+    std::string name = CompatibleInfo::ValueName(value);
     if (!allow_duplicate && repeat.count(name)) {
       continue;
     }
@@ -84,7 +79,7 @@ std::vector<std::string> CompatibleInfo::OutputNames(
   std::vector<std::string> names;
   for (int i = 0; i < op.num_results(); ++i) {
     auto value = op.result(i);
-    std::string name = CompatibleInfo::OutputName(value);
+    std::string name = CompatibleInfo::ValueName(value);
     names.push_back(std::move(name));
   }
   return names;
......
@@ -24,17 +24,14 @@ namespace framework {
 namespace newir {
 struct CompatibleInfo {
-  static constexpr char* kInputPrefix = "input_";
-  static constexpr char* kOutputPrefix = "output_";
+  static constexpr char* kNamePrefix = "var_";
   // TODO(Aurelius): Need add name mapping logic in REGISTER_CINN_OP
   // macros or attempt to unify Op name with Paddle and CINN.
   static const std::unordered_map<std::string, std::string> OP_NAMES;
   static std::string OpName(const ::ir::Operation& op);
-  static std::string InputName(const ::ir::Value& value);
-  static std::string OutputName(const ::ir::Value& value);
+  static std::string ValueName(const ::ir::Value& value);
   static std::string OpFuncName(const ::ir::Operation& op);
......
@@ -35,7 +35,6 @@ std::unique_ptr<Program> NewIRCompiler::Build() {
        ++it) {
     std::vector<::ir::Operation*> ops = {*it};
     groups.push_back(std::make_shared<newir::Group>(ops));
-    groups.back()->fn_name = CompatibleInfo::GroupOpsName(groups.back()->ops);
   }
   VLOG(4) << "Groups size: " << groups.size();
   return std::move(Build(groups));
@@ -103,23 +102,20 @@ std::vector<std::unique_ptr<Instruction>> NewIRCompiler::BuildInstructions(
     const std::vector<newir::GroupPtr>& groups) {
   std::vector<std::unique_ptr<Instruction>> instructions;
   for (int idx = 0; idx < groups.size(); ++idx) {
-    // TODO(Aurelius84): only support single op in groups
-    auto& op = *(groups[idx]->ops[0]);
     auto& fn_name = groups[idx]->fn_name;
-    auto instr = std::unique_ptr<Instruction>(
-        new Instruction(target_,
-                        scope_.get(),
-                        CompatibleInfo::InputNames(op),
-                        CompatibleInfo::OutputNames(op),
-                        fn_name));
+    auto instr =
+        std::unique_ptr<Instruction>(new Instruction(target_,
+                                                     scope_.get(),
+                                                     groups[idx]->input_names,
+                                                     groups[idx]->output_names,
+                                                     fn_name));
     VLOG(1) << "Lookup kernel name: " << fn_name;
     auto* fn_ptr = compiler_->Lookup(fn_name);
     CHECK(fn_ptr);
     instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), fn_name);
     // As some instruction like reduce, will generate more than one kernel.
     // So try to find the rest kernel, if it exists.
-    // SetSubKernels(instr.get(), op_func_name);
+    // SetSubKernels(instr.get(), fn_name);
     instr->Finalize();
     instructions.push_back(std::move(instr));
   }
@@ -131,16 +127,15 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
   std::unordered_set<::ir::Value> visited;
   auto scope = std::make_shared<Scope>();
-  auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
+  auto create_var = [&](::ir::Value value) {
     if (visited.count(value) > 0) return;
     visited.emplace(value);
-    std::string name =
-        name_prefix + std::to_string(std::hash<::ir::Value>()(value));
+    std::string name = CompatibleInfo::ValueName(value);
     auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
     auto* var = scope->Var<Tensor>(name);
     auto& tensor = absl::get<Tensor>(*var);
-    // NOTE: can be replaced with phi::vectorized ?
     std::vector<Shape::dim_t> shape;
     for (auto i = 0; i < type_info.dims().size(); ++i) {
       shape.push_back(Shape::dim_t(type_info.dims()[i]));
@@ -150,14 +145,12 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
   };
   for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
-    for (auto i = 0; i < (*it)->num_operands(); ++i) {
-      auto in_value = (*it)->operand_source(i);
-      create_var(CompatibleInfo::kInputPrefix, in_value);
+    for (auto& oprand : (*it)->operands()) {
+      create_var(oprand.source());
     }
-    for (auto i = 0; i < (*it)->num_results(); ++i) {
-      auto out_value = (*it)->result(i);
-      create_var(CompatibleInfo::kOutputPrefix, out_value);
+    for (auto& result : (*it)->results()) {
+      create_var(result);
     }
   }
   return scope;
......
@@ -40,11 +40,11 @@ class NewIRCompiler final {
   std::unique_ptr<Program> Build();
+  std::unique_ptr<Program> Build(const std::vector<newir::GroupPtr>& groups);
  private:
   CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler);
-  std::unique_ptr<Program> Build(const std::vector<newir::GroupPtr>& groups);
   std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx);
   void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);
......
@@ -17,6 +17,7 @@
 #include <memory>
 #include <sstream>
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
@@ -31,10 +32,15 @@
 #include "paddle/cinn/hlir/framework/convert_to_dialect.h"
 #include "paddle/cinn/hlir/framework/new_ir_compiler.h"
-std::unique_ptr<::ir::Program> BuildProgram() {
+using cinn::hlir::framework::newir::Group;
+using cinn::hlir::framework::newir::GroupPtr;
+using ProgramInfo =
+    std::tuple<std::shared_ptr<::ir::Program>, std::vector<GroupPtr>>;
+
+ProgramInfo BuildProgram() {
   ::ir::IrContext* ctx = ::ir::IrContext::Instance();
   ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
-  auto program = std::make_unique<::ir::Program>(ctx);
+  auto program = std::make_shared<::ir::Program>(ctx);
   ::ir::Builder builder = ::ir::Builder(ctx, program->block());
   const float value_one = 1.0;  // relu(tan(1.)) = 1.5;
@@ -51,17 +57,30 @@ std::unique_ptr<::ir::Program> BuildProgram() {
                                              phi::DataType::FLOAT32,
                                              phi::GPUPlace());
-  auto tanh_op_x = builder.Build<paddle::dialect::TanOp>(full_op_x->result(0));
-  auto relu_op_x = builder.Build<paddle::dialect::ReluOp>(tanh_op_x->result(0));
-  auto tanh_op_y = builder.Build<paddle::dialect::TanOp>(full_op_y->result(0));
-  auto relu_op_y = builder.Build<paddle::dialect::ReluOp>(tanh_op_y->result(0));
-  return std::move(program);
+  auto tan_op_x = builder.Build<paddle::dialect::TanOp>(full_op_x->result(0));
+  auto relu_op_x = builder.Build<paddle::dialect::ReluOp>(tan_op_x->result(0));
+  auto tan_op_y = builder.Build<paddle::dialect::TanOp>(relu_op_x->result(0));
+  auto relu_op_y = builder.Build<paddle::dialect::ReluOp>(tan_op_y->result(0));
+
+  std::vector<GroupPtr> groups;
+  groups.emplace_back(
+      std::make_shared<Group>(std::initializer_list<::ir::Operation*>(
+          {full_op_x.operation()})));  // For coverage
+  groups.emplace_back(std::make_shared<Group>(
+      std::initializer_list<::ir::Operation*>({full_op_y.operation()})));
+  groups.emplace_back(std::make_shared<Group>(
+      std::vector<::ir::Operation*>({tan_op_x.operation(),
+                                     relu_op_x.operation(),
+                                     tan_op_y.operation(),
+                                     relu_op_y.operation()})));
+
+  return {program, groups};
 }
 TEST(NewIRCompier, CompilerAndRun) {
   // Step 1: Construct ir::Program
-  std::unique_ptr<::ir::Program> program = BuildProgram();
+  auto prog_info = BuildProgram();
+  std::shared_ptr<::ir::Program> program = std::get<0>(prog_info);
   EXPECT_EQ(program->block()->size(), 6u);
   LOG(INFO) << program->block()->size();
@@ -89,9 +108,42 @@ TEST(NewIRCompier, CompilerAndRun) {
   }
 }
+TEST(NewIRCompier, CompileGroupOps) {
+  // Step 1: Construct ir::Program
+  auto prog_info = BuildProgram();
+  std::shared_ptr<::ir::Program> program = std::get<0>(prog_info);
+  std::vector<GroupPtr> groups = std::get<1>(prog_info);
+  EXPECT_EQ(program->block()->size(), 6u);
+  LOG(INFO) << program->block()->size();
+
+  std::stringstream ss;
+  program->Print(ss);
+  LOG(INFO) << ss.str();
+
+  // Step 2: Compiler New ir::Program into Runtime Program
+  auto target = cinn::common::DefaultNVGPUTarget();
+  auto scope = cinn::hlir::framework::BuildScope(target, *program);
+  ASSERT_EQ(scope->var_names().size(), 6);
+
+  cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope);
+  auto runtime_program = ir_compiler.Build(groups);
+
+  // Step 3: Execute Runtime Instruction and check Scope.
+  ASSERT_NO_THROW(runtime_program->Execute());
+  for (auto& var_name : scope->var_names()) {
+    std::string name = {var_name.begin(), var_name.end()};
+    std::vector<float> data =
+        cinn::GetTensorData<float>(scope->GetTensor(name), target);
+    for (int i = 0; i < 1; ++i) {
+      LOG_FIRST_N(INFO, 10) << "data: " << data[i];
+    }
+  }
+}
+
 TEST(RuntimeDialect, CompilerAndRun) {
   // Step 1: Construct ir::Program
-  std::unique_ptr<::ir::Program> program = BuildProgram();
+  auto prog_info = BuildProgram();
+  std::shared_ptr<::ir::Program> program = std::get<0>(prog_info);
   EXPECT_EQ(program->block()->size(), 6u);
   // Step 2: Compiler New ir::Program into Runtime Program
@@ -103,7 +155,7 @@ TEST(RuntimeDialect, CompilerAndRun) {
   auto runtime_program = ir_compiler.Build();
   // Step 3: Convert into cinn::dialect::RuntimeDialect
-  std::unique_ptr<::ir::Program> ir_runtime_program =
+  std::shared_ptr<::ir::Program> ir_runtime_program =
       cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program);
   // Step 4: Run cinn::dialect::RuntimeDialect
......