未验证 提交 12823f2b 编写于 作者: A Aurelius84 提交者: GitHub

[NewIR]Add cinn RuntimeDialect and JitKernelOp (#56074)

* [NewIR]Add cinn RuntimeDialect and JitKernelOp

* remove PointerAttribute register

* fix comment
上级 fa878846
......@@ -3,3 +3,4 @@ add_subdirectory(pe)
add_subdirectory(op)
add_subdirectory(pass)
add_subdirectory(kernels)
add_subdirectory(dialect)
# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and cannot
# be found under CINN_ONLY mode, so the dialect library is only built
# in full-compilation mode.
if(NOT CINN_ONLY)
cinn_cc_library(cinn_dialect SRCS runtime_dialect.cc jit_kernel_op.cc DEPS
pd_dialect)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/enforce.h"
namespace cinn {
namespace dialect {
const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName};
// Validates this op's attributes: it must carry an `instruction`
// attribute whose value is an ::ir::PointerAttribute. Raises through
// IR_ENFORCE on violation.
void JitKernelOp::Verify() {
  VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp.";
  const auto& attr_map = this->attributes();
  auto iter = attr_map.find(kAttrName);
  const bool holds_pointer_attr =
      iter != attr_map.end() && iter->second.isa<::ir::PointerAttribute>();
  IR_ENFORCE(holds_pointer_attr,
             "Type of attribute: instruction is not right.");
}
// Returns the jit-compiled instruction carried by this op. The pointer
// is stored as an ::ir::PointerAttribute under kAttrName; the op does
// not own it — the producer of the attribute is responsible for its
// lifetime.
hlir::framework::Instruction* JitKernelOp::instruction() {
  // Fully qualify ::ir to match the rest of this file and avoid
  // accidental lookup of a nested `ir` namespace inside `cinn`.
  void* ptr =
      attributes().at(kAttrName).dyn_cast<::ir::PointerAttribute>().data();
  return reinterpret_cast<hlir::framework::Instruction*>(ptr);
}
} // namespace dialect
} // namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::JitKernelOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/op_base.h"
namespace cinn {
namespace hlir {
namespace framework {
class Instruction;
} // namespace framework
} // namespace hlir
namespace dialect {
/*
 * TODO(Aurelius84): THIS IS NOT THE FINAL STATE!
 * JitKernel is a unified runtime operation used to represent a
 * jit-compiled function ptr from a backend, such as nvrtc.
 * Ideally, JitKernel should only contain an ArrayAttribute in
 * which each element is a PointerAttribute, i.e. the jit
 * function ptr itself.
 * Currently, we temporarily store an hlir::framework::Instruction,
 * and will split executor information such as scope, inputs and
 * outputs into the InterpreterCore module.
 */
class JitKernelOp : public ::ir::Op<JitKernelOp> {
 public:
  using Op::Op;
  static const char* name() { return "cinn.jit_kernel"; }
  // TODO(Aurelius84): Think deeply what should contains
  static constexpr uint32_t attributes_num = 1;
  // NOTE: a string literal has type `const char[]`; binding it to a
  // non-const `char*` is ill-formed in standard C++, so the pointee
  // must be const.
  static constexpr const char* kAttrName = "instruction";
  static const char* attributes_name[attributes_num];

  // Returns the jit-compiled instruction stored in the `instruction`
  // pointer attribute (non-owning).
  hlir::framework::Instruction* instruction();

  // Checks that the `instruction` attribute exists and is a
  // PointerAttribute.
  void Verify();
};
} // namespace dialect
} // namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::JitKernelOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
namespace cinn {
namespace dialect {
// Registers this dialect (and its unique TypeId) into `context`, then
// registers the dialect's ops via initialize().
RuntimeDialect::RuntimeDialect(::ir::IrContext* context)
    : ::ir::Dialect(
          name(), context, ::ir::TypeId::get<cinn::dialect::RuntimeDialect>()) {
  this->initialize();
}
// Registers every op owned by this dialect; currently only JitKernelOp.
void RuntimeDialect::initialize() { RegisterOps<cinn::dialect::JitKernelOp>(); }
} // namespace dialect
} // namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::RuntimeDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dialect.h"
namespace cinn {
namespace dialect {
// Dialect hosting CINN runtime ops (e.g. "cinn.jit_kernel").
// Ops are registered in initialize(), which the constructor invokes.
class RuntimeDialect : public ::ir::Dialect {
 public:
  explicit RuntimeDialect(::ir::IrContext* context);

  // Dialect namespace prefix; ops registered here are named "cinn.<op>".
  static const char* name() { return "cinn"; }

 private:
  void initialize();
};
} // namespace dialect
} // namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::RuntimeDialect)
......@@ -26,6 +26,8 @@ gather_srcs(
if(NOT CINN_ONLY)
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinncore
pd_dialect)
cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinncore
cinn_dialect)
endif()
if(WITH_CUDA)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include <string>
#include <unordered_map>
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
namespace cinn {
namespace hlir {
namespace framework {
// Lowers a compiled hlir::framework::Program into an ::ir::Program in
// the cinn RuntimeDialect: one "cinn.jit_kernel" op is emitted per run
// instruction, carrying the raw Instruction* as a PointerAttribute.
//
// NOTE(review): the returned ir::Program only borrows the Instruction
// pointers, so `program` must outlive it.
std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
    const hlir::framework::Program& program) {
  ::ir::IrContext* ctx = ::ir::IrContext::Instance();
  ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
  auto ir_program = std::make_unique<::ir::Program>(ctx);

  // The op info is identical for every instruction; look it up once.
  std::string jit_op_name = dialect::JitKernelOp::name();
  ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);

  const auto& instrs = program.GetRunInstructions();
  for (const auto& instr : instrs) {
    std::unordered_map<std::string, ::ir::Attribute> op_attrs{
        {dialect::JitKernelOp::kAttrName,
         ::ir::PointerAttribute::get(ctx, instr.get())},
    };

    ::ir::Operation* cinn_op =
        ::ir::Operation::Create({}, op_attrs, {}, op_info);
    ir_program->block()->push_back(cinn_op);
  }
  // Return the local directly: `return std::move(local)` inhibits copy
  // elision (clang-tidy: performance-no-automatic-move).
  return ir_program;
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
namespace ir {
class Program;
} // namespace ir
namespace cinn {
namespace hlir {
namespace framework {
class Program;
// Converts a compiled hlir::framework::Program into an ::ir::Program in
// the cinn RuntimeDialect (one "cinn.jit_kernel" op per run
// instruction). The result borrows the program's Instruction pointers,
// so `program` must outlive the returned ir::Program.
std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
    const hlir::framework::Program& program);
} // namespace framework
} // namespace hlir
} // namespace cinn
......@@ -56,10 +56,11 @@ class Program {
*/
size_t size() const { return instrs_.size(); }
const std::vector<std::unique_ptr<Instruction>>& GetPreRunInstructions() {
const std::vector<std::unique_ptr<Instruction>>& GetPreRunInstructions()
const {
return prerun_instrs_;
}
const std::vector<std::unique_ptr<Instruction>>& GetRunInstructions() {
const std::vector<std::unique_ptr<Instruction>>& GetRunInstructions() const {
return instrs_;
}
......
......@@ -5,6 +5,7 @@ if(WITH_TESTING AND WITH_CINN)
new_ir_compiler_test.cc
DEPS
new_ir_compiler
convert_to_dialect
ir
phi
gtest
......
......@@ -14,7 +14,10 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
......@@ -25,13 +28,16 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
TEST(GraphCompier, TestNewIR) {
std::unique_ptr<::ir::Program> BuildProgram() {
::ir::IrContext* ctx = ::ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
::ir::Program program(ctx);
::ir::Builder builder = ::ir::Builder(ctx, program.block());
auto program = std::make_unique<::ir::Program>(ctx);
::ir::Builder builder = ::ir::Builder(ctx, program->block());
const float value = 2.0;
auto full_op_x =
......@@ -47,23 +53,75 @@ TEST(GraphCompier, TestNewIR) {
phi::GPUPlace());
// TODO(Aurelius84): test more op
// auto add_z = builder.Build<paddle::dialect::MatmulOp>(full_op_x->result(0),
// full_op_y->result(0));
// full_op_y->result(0));
return std::move(program);
}
EXPECT_EQ(program.block()->size(), 2u);
TEST(NewIRCompier, CompilerAndRun) {
// Step 1: Construct ir::Program
std::unique_ptr<::ir::Program> program = BuildProgram();
EXPECT_EQ(program->block()->size(), 2u);
std::stringstream ss;
program.Print(ss);
program->Print(ss);
LOG(INFO) << ss.str();
// Step 2: Compiler New ir::Program into Runtime Program
auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, program);
auto scope = cinn::hlir::framework::BuildScope(target, *program);
ASSERT_EQ(scope->var_names().size(), 2);
cinn::hlir::framework::NewIRCompiler ir_compiler(program, target, scope);
cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope);
auto runtime_program = ir_compiler.Build();
// Step 3: Execute Runtime Instruction and check Scope.
ASSERT_NO_THROW(runtime_program->Execute());
const float value = 2.0;
for (auto& var_name : scope->var_names()) {
std::string name = {var_name.begin(), var_name.end()};
std::vector<float> data =
cinn::GetTensorData<float>(scope->GetTensor(name), target);
for (int i = 0; i < data.size(); ++i) {
LOG_FIRST_N(INFO, 3) << "data: " << data[i];
ASSERT_NEAR(data[i], value, 1e-5);
}
}
}
TEST(RuntimeDialect, CompilerAndRun) {
// Step 1: Construct ir::Program
std::unique_ptr<::ir::Program> program = BuildProgram();
EXPECT_EQ(program->block()->size(), 2u);
// Step 2: Compiler New ir::Program into Runtime Program
auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, *program);
ASSERT_EQ(scope->var_names().size(), 2);
cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope);
auto runtime_program = ir_compiler.Build();
// Step 3: Convert into cinn::dialect::RuntimeDialect
std::unique_ptr<::ir::Program> ir_runtime_program =
cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program);
// Step 4: Run cinn::dialect::RuntimeDialect
for (auto iter = ir_runtime_program->block()->begin();
iter != ir_runtime_program->block()->end();
++iter) {
auto op = (*iter)->dyn_cast<cinn::dialect::JitKernelOp>();
auto* instr = op.instruction();
instr->Run(/*name2podargs=*/nullptr,
false,
/*stream=*/nullptr,
/*use_cache=*/true);
}
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaDeviceSynchronize());
#endif
// Step 5: Check Scope Tensor Value.
const float value = 2.0;
for (auto& var_name : scope->var_names()) {
std::string name = {var_name.begin(), var_name.end()};
std::vector<float> data =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册