Unverified  Commit 72a910e4 authored by Aurelius84, committed by GitHub

[NewIR]Replace frontend::Program & hlir::Graph with ::ir::Program in CINN (#55186)

Parent fd192303
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/utils/attribute_util.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace cinn {
namespace hlir {
namespace framework {
// TODO(Aurelius): Need to add name-mapping logic in the REGISTER_CINN_OP
// macros, or attempt to unify Op names between Paddle and CINN.
static const std::unordered_map<std::string, std::string> OP_NAMES = {
{"pd.full", "fill_constant"}, {"pd.matmul", "matmul"}};
// TODO(Aurelius84): Need to abstract this logic into a Proxy to support
// co-existence with GraphCompiler.
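// NewIRCompiler compiles a ::ir::Program directly into a runtime Program,
// replacing the frontend::Program and hlir::Graph pipeline.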
class NewIRCompiler final {
public:
NewIRCompiler(const ::ir::Program& prog,
const Target& target,
const std::shared_ptr<Scope>& scope)
: program_(prog),
m_builder_("NewIR", target), // TODO(dev): need unique name
target_(target),
scope_(scope) {}
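// Compiles the program in four steps: group the ops (currently one op per
// group), lower each group to LoweredFuncs, JIT-compile the resulting
// module, and assemble the runtime Instructions.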
std::unique_ptr<Program> Build() {
m_builder_.Clear();
// NOTE(Aurelius84): Currently only one op per group is supported.
std::vector<std::vector<::ir::Operation*>> groups;
for (auto it = program_.block()->begin(); it != program_.block()->end();
++it) {
groups.push_back({*it});
}
VLOG(4) << "Groups size: " << groups.size();
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
for (int i = 0; i < groups.size(); ++i) {
lowered_funcs.emplace_back(GetOpFunc(*groups[i][0], i));
}
for (auto&& lowered_func : lowered_funcs) {
ProcessFunction(lowered_func);
}
compiler_ = backends::Compiler::Create(target_);
auto build_module = m_builder_.Build();
compiler_->Build(build_module, "");
auto instructions = BuildInstructions(groups);
return std::make_unique<Program>(scope_, std::move(instructions));
}
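// Lowers a single op: creates placeholder tensors for its operands,
// collects output types and shapes, then runs the op's CINN compute and
// schedule strategies and lowers the staged computation to functions.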
std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx) {
std::vector<ir::Tensor> inputs;
std::vector<common::CINNValue> cinn_inputs;
VLOG(4) << "GetOpFunc for op: " << op.name();
// step 1: Deal with Operands
for (int i = 0; i < op.num_operands(); ++i) {
auto in_value = op.operand(i);
// TODO(Aurelius84): For now, use addr as name but it's not wise.
std::string input_id = std::to_string(std::hash<::ir::Value>()(in_value));
// NOTE(Aurelius84): Do we need to support other Types?
auto type_info =
in_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto in_shape = phi::vectorize<int>(type_info.dims());
ir::Tensor temp;
auto dtype = type_info.dtype();
// TODO(Aurelius84): support more types
if (dtype.isa<::ir::Float32Type>()) {
temp = lang::Placeholder<float>(input_id, in_shape);
} else if (dtype.isa<::ir::Int32Type>()) {
temp = lang::Placeholder<int>(input_id, in_shape);
}
inputs.push_back(temp);
cinn_inputs.push_back(common::CINNValue(temp));
}
for (auto out_name : OpGetOutputNames(op)) {
cinn_inputs.push_back(
common::CINNValue(op.name().substr(3) + "_" + out_name));
}
VLOG(4) << "inputs.size(): " << inputs.size();
// step 2: Deal with OpResult
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
for (int i = 0; i < op.num_results(); ++i) {
auto out_value = op.result(i);
auto type_info =
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
// TODO(Aurelius84): need to support ::ir::Type -> common::Type
out_types.push_back(common::Float(32));
auto out_shape = phi::vectorize<int>(type_info.dims());
out_shapes.push_back(std::move(out_shape));
}
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs;
{
VLOG(4) << "op.attributes():" << op.attributes().size();
auto attrs = utils::ConvertAttributes(op.attributes());
node_attrs.node_name = OP_NAMES.at(op.name());
node_attrs.attr_store = std::move(attrs);
}
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
// NOTE(Aurelius84): Do we need to replace all hlir::framework Operators
// with ::ir::Program?
const hlir::framework::Operator* cinn_op =
Operator::Get(OP_NAMES.at(op.name()));
auto impl = OpStrategy::SelectImpl(
strategy[cinn_op](node_attrs, inputs, out_types, out_shapes, target_));
common::CINNValuePack C =
impl->fcompute(common::CINNValuePack{cinn_inputs});
poly::StageMap stages = C.back();
// Make sure all tensors are in the stages before launching the schedule.
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
stages->InsertLazily(temp.as_tensor_ref());
}
C = impl->fschedule(C);
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
// Check whether the tensor has a buffer.
if ((!temp.as_tensor_ref()->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) &&
!stages[temp.as_tensor_ref()]->inlined()) {
inputs.push_back(temp.as_tensor_ref());
}
}
auto func = lang::LowerVec(
GenOpFuncName(op, idx), stages, inputs, {}, {}, nullptr, target_);
return func;
}
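// Adds each lowered function to the module builder, creating scope
// variables for buffer arguments that do not exist in the scope yet.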
void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs) {
for (auto&& func : lowered_funcs) {
for (auto&& arg : func->args) {
std::string arg_name = arg.name();
if (arg_name[0] == '_') arg_name = arg_name.substr(1);
auto* var = scope_->FindVar(arg_name);
// For buffer arguments not in the scope, create them.
if (!var && arg.is_buffer()) {
auto* new_var = scope_->Var<Tensor>(arg_name);
auto& tensor = absl::get<Tensor>(*new_var);
std::vector<Shape::dim_t> shape;
for (auto& shape_dim : arg.buffer_arg()->shape) {
CHECK(shape_dim.is_constant());
shape.push_back(static_cast<int>(shape_dim.get_constant()));
}
tensor->Resize(Shape{shape});
tensor->set_type(arg.buffer_arg()->dtype);
}
}
m_builder_.AddFunction(func);
}
}
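// Creates one runtime Instruction per group and binds the JIT-compiled
// function pointer looked up by the generated function name.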
std::vector<std::unique_ptr<Instruction>> BuildInstructions(
const std::vector<std::vector<::ir::Operation*>>& groups) {
std::vector<std::unique_ptr<Instruction>> instructions;
for (int idx = 0; idx < groups.size(); ++idx) {
// TODO(Aurelius84): only support single op in groups
auto& op = *groups[idx][0];
auto instr_name = op.name();
auto instr =
std::unique_ptr<Instruction>(new Instruction(target_,
scope_.get(),
OpGetInputNames(op),
OpGetOutputNames(op),
instr_name));
auto& op_func_name = GenOpFuncName(op, idx);
auto* fn_ptr = compiler_->Lookup(op_func_name);
CHECK(fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), op_func_name);
// Some instructions, such as reduce, generate more than one kernel,
// so try to find the remaining kernels if they exist.
// SetSubKernels(instr.get(), op_func_name);
instr->Finalize();
instructions.push_back(std::move(instr));
}
return instructions;
}
protected:
const std::string& GenOpFuncName(const ::ir::Operation& op, int idx) {
// TODO(Aurelius84): The '.' in op names like pd.xxx will raise a compiler
// error, so we need a more elegant way to generate function names.
std::string op_name = op.name().substr(3) + "_" + std::to_string(idx);
std::string func_name = Context::Global().NewName("fn_" + op_name);
func_names_.try_emplace(op_name, func_name);
return func_names_.at(op_name);
}
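// Returns the de-duplicated operand names of an op; values are named by
// their ::ir::Value hash, matching the naming used in BuildScope.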
std::vector<std::string> OpGetInputNames(const ::ir::Operation& op) {
std::vector<std::string> names;
std::unordered_set<std::string> repeat;
for (int i = 0; i < op.num_operands(); ++i) {
auto value = op.operand(i);
std::string name = std::to_string(std::hash<::ir::Value>()(value));
if (repeat.count(name)) {
continue;
}
repeat.insert(name);
names.push_back(name);
}
return names;
}
std::vector<std::string> OpGetOutputNames(const ::ir::Operation& op) {
std::vector<std::string> names;
for (int i = 0; i < op.num_results(); ++i) {
auto value = op.result(i);
std::string name = std::to_string(std::hash<::ir::Value>()(value));
names.push_back(std::move(name));
}
return names;
}
private:
const ::ir::Program& program_;
ir::Module::Builder m_builder_;
std::unique_ptr<backends::Compiler> compiler_;
Target target_;
std::shared_ptr<Scope> scope_;
std::unordered_map<std::string, std::string> func_names_;
};
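// Creates a Scope holding one Tensor variable for every operand and
// result value in the program, named by the ::ir::Value hash.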
std::shared_ptr<Scope> BuildScope(const Target& target,
const ::ir::Program& program) {
std::unordered_set<::ir::Value> visited;
auto scope = std::make_shared<Scope>();
auto create_var = [&](::ir::Value value) {
if (visited.count(value) > 0) return;
visited.emplace(value);
std::string name = std::to_string(std::hash<::ir::Value>()(value));
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var);
// NOTE: Can this be replaced with phi::vectorize?
std::vector<Shape::dim_t> shape;
for (auto i = 0; i < type_info.dims().size(); ++i) {
shape.push_back(Shape::dim_t(type_info.dims()[i]));
}
tensor->Resize(Shape{shape});
// TODO(Aurelius84): need convert this.
tensor->set_type(common::Float(32));
};
for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
// Visit OpOperands
for (auto i = 0; i < (*it)->num_operands(); ++i) {
auto in_value = (*it)->operand(i);
create_var(in_value);
}
for (auto i = 0; i < (*it)->num_results(); ++i) {
auto out_value = (*it)->result(i);
create_var(out_value);
}
}
return scope;
}
} // namespace framework
} // namespace hlir
} // namespace cinn
@@ -32,21 +32,47 @@ CINNSchedule GetElementwiseScheduleFunc(
CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
"empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
CHECK_GT(arg_pack.size(), 0U)
<< "arg_pack.size() must contain at least one element.";
// TODO(Aurelius84): For NewIrCompiler, the outputs of Compute are
// tensor_refs rather than Exprs.
bool is_tensor_stages = arg_pack.size() == 2U && arg_pack[0].is_tensor() &&
arg_pack[1].is_stagemap();
if (!is_tensor_stages) {
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
std::vector<common::CINNValue> res{
common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
} else {
CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
"empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(out.as_tensor());
CHECK_EQ(arg_pack.size(), 2UL);
if (target.arch == Target::Arch::NVGPU) {
pe::CudaScheduleInjective(
stages[out.as_tensor_ref()], output_shapes.front(), target);
} else if (target.arch == Target::Arch::X86) {
pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
output_shapes.front(),
target,
vectorizable);
}
*ret = arg_pack;
}
});
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/phi/common/data_type.h"
namespace cinn {
namespace utils {
using NewIR_AttributeMap = std::unordered_map<std::string, ::ir::Attribute>;
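// Converts a single NewIR ::ir::Attribute into CINN's Attribute variant;
// unsupported attribute kinds abort with LOG(FATAL).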
Attribute ConvertAttribute(const ::ir::Attribute& src_attr) {
Attribute dst_attr;
if (src_attr.isa<::ir::BoolAttribute>()) {
dst_attr = src_attr.dyn_cast<::ir::BoolAttribute>().data();
} else if (src_attr.isa<::ir::FloatAttribute>()) {
dst_attr = src_attr.dyn_cast<::ir::FloatAttribute>().data();
} else if (src_attr.isa<::ir::Int32Attribute>()) {
dst_attr = src_attr.dyn_cast<::ir::Int32Attribute>().data();
} else if (src_attr.isa<::ir::StrAttribute>()) {
dst_attr = src_attr.dyn_cast<::ir::StrAttribute>().AsString();
} else if (src_attr.isa<::ir::Int64Attribute>()) {
dst_attr = src_attr.dyn_cast<::ir::Int64Attribute>().data();
} else if (src_attr.isa<::ir::DoubleAttribute>()) {
dst_attr = src_attr.dyn_cast<::ir::DoubleAttribute>().data();
} else if (src_attr.isa<paddle::dialect::IntArrayAttribute>()) {
auto arr = src_attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();
std::vector<int> val;
for (size_t i = 0; i < arr.size(); ++i) {
val.push_back(arr[i]);
}
dst_attr = val;
} else if (src_attr.isa<paddle::dialect::DataTypeAttribute>()) {
// TODO(Aurelius84): Need to add conversion logic from phi::DataType to a
// CINN String.
auto dtype = src_attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
dst_attr = phi::DataTypeToString(dtype);
} else {
LOG(FATAL) << "unknown Attribute: " << src_attr;
}
return dst_attr;
}
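// Converts a whole NewIR attribute map into CINN's AttributeMap. Place
// attributes are not converted yet; they currently degrade to a
// force_cpu=false entry.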
AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
AttributeMap dst_attrs;
for (auto& item : src_attrs) {
VLOG(4) << "deal with " << item.first;
if (!item.second.isa<paddle::dialect::PlaceAttribute>()) {
dst_attrs[item.first] = std::move(ConvertAttribute(item.second));
} else {
// TODO(Aurelius84): support place attribute for special Op
dst_attrs["force_cpu"] = false;
}
}
VLOG(4) << "dst_attrs.size(): " << dst_attrs.size();
return dst_attrs;
}
} // namespace utils
} // namespace cinn
@@ -2,3 +2,4 @@ add_subdirectory(core)
add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
add_subdirectory(kernel_dialect)
add_subdirectory(cinn)
if(WITH_TESTING AND WITH_CINN)
cc_test_old(
test_graph_compiler_new_ir
SRCS
graph_compiler_new_ir_test.cc
DEPS
cinncore
pd_dialect
ir
phi
gtest
glog)
set_tests_properties(test_graph_compiler_new_ir PROPERTIES LABELS
"RUN_TYPE=CINN")
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <sstream>
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
TEST(GraphCompier, TestNewIR) {
::ir::IrContext* ctx = ::ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
::ir::Program program(ctx);
::ir::Builder builder = ::ir::Builder(ctx, program.block());
auto full_op_x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128},
1.0,
phi::DataType::FLOAT32,
phi::CPUPlace());
auto full_op_y =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64},
2.0,
phi::DataType::FLOAT32,
phi::CPUPlace());
// TODO(Aurelius84): test more ops
// auto add_z = builder.Build<paddle::dialect::MatmulOp>(full_op_x->result(0),
// full_op_y->result(0));
EXPECT_EQ(program.block()->size(), 2u);
std::stringstream ss;
program.Print(ss);
LOG(INFO) << ss.str();
auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, program);
ASSERT_EQ(scope->var_names().size(), 2);
cinn::hlir::framework::NewIRCompiler ir_compiler(program, target, scope);
auto runtime_program = ir_compiler.Build();
// FIXME(Aurelius84): It raised an illegal memory access in the destructor
// after running all instructions, but it's OK under GLOG_v=10.
// ASSERT_NO_THROW(runtime_program->Execute());
}