未验证 提交 4191f2c6 编写于 作者: A Aurelius84 提交者: GitHub

[NewIR]Split NewIRCompiler with .h/.cc and decoupling compilation with cinncore (#55733)

* [NewIR]Split NewIRCompiler with .h/.cc and decoupling compilation with cinncore

* fix cmake

* fix CINN_ONLY
上级 ff2142f2
......@@ -9,6 +9,7 @@ gather_srcs(
buffer.cc
memory.cc
instruction.cc
program.cc
parallel_compiler.cc
graph_compiler.cc
graph.cc
......@@ -20,6 +21,13 @@ gather_srcs(
accuracy_checker.cc
visualize_helper.cc)
# TODO(Aurelius84): new_ir_compiler depends on pd_dialect, which cannot
# be found under CINN_ONLY mode
if(NOT CINN_ONLY)
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinncore
pd_dialect)
endif()
if(WITH_CUDA)
cinn_nv_test(test_hlir_framework_buffer SRCS buffer_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_accuracy_checker SRCS
......
......@@ -37,220 +37,6 @@ namespace framework {
using cinn::common::bfloat16;
using cinn::common::float16;
// Copies the attributes named in `attrs_name` from `attrs_store` into `instr`.
// The variant alternative index of each stored value picks the destination:
//   2 (int)              -> appended to instr->attrs
//   3 (std::string)      -> appended to instr->str_attrs
//   5 (std::vector<int>) -> every element appended to instr->attrs
// Values holding any other alternative are ignored. A name that is absent
// from `attrs_store` is reported through LOG(ERROR) and skipped.
void AddAttrs(const absl::flat_hash_map<std::string, AttrType>& attrs_store,
              const std::vector<std::string>& attrs_name,
              Instruction* instr) {
  for (const auto& name : attrs_name) {
    const auto found = attrs_store.find(name);
    if (found == attrs_store.end()) {
      LOG(ERROR) << "Param " << name << " missed! Please check.";
      continue;
    }
    const auto& value = found->second;
    const auto kind = value.index();
    if (kind == 2) {
      instr->attrs.push_back(absl::get<int>(value));
    } else if (kind == 3) {
      instr->str_attrs.push_back(absl::get<std::string>(value));
    } else if (kind == 5) {
      const auto& ints = absl::get<std::vector<int>>(value);
      instr->attrs.insert(instr->attrs.end(), ints.begin(), ints.end());
    }
  }
}
// Builds a program that keeps `scope` alive and takes ownership of every
// instruction in `instrs`: those flagged `pre_run` go to prerun_instrs_, all
// others go to instrs_. Relative order is preserved within each list.
Program::Program(const std::shared_ptr<Scope>& scope,
                 std::vector<std::unique_ptr<Instruction>>&& instrs)
    : scope_(scope) {
  for (auto& instruction : instrs) {
    auto& destination = instruction->pre_run ? prerun_instrs_ : instrs_;
    destination.push_back(std::move(instruction));
  }
}
// Runs every pre-run instruction, then calls PreRun() on each runtime
// instruction whose size() equals 4.
// NOTE(review): the meaning of size() == 4 is not visible from here; confirm
// against Instruction before changing it.
void Program::PreRun(
    const std::map<std::string, cinn_pod_value_t>* name2podargs) {
  for (auto& prerun : prerun_instrs_) {
    prerun->Run(name2podargs);
  }
  for (auto& runtime : instrs_) {
    if (runtime->size() != 4) continue;
    runtime->PreRun(name2podargs);
  }
}
void Program::Export(const std::vector<std::string>& persistent_vars,
const std::string& filename) {
auto writeplaceholder = [=](int s, int n, FILE* f) -> int {
int pos = ftell(f);
for (int i = 0; i < s * n; i++) {
fwrite("\0", 1, 1, f);
}
return pos;
};
auto setplaceholder = [=](int p, void* b, int s, int n, FILE* f) {
int cur = ftell(f);
fseek(f, p, SEEK_SET);
fwrite(b, s, n, f);
fseek(f, cur, SEEK_SET);
};
auto tellplaceholder = [=](int p, FILE* f) {
int cur = ftell(f);
setplaceholder(p, &cur, 4, 1, f);
};
auto padding = [=](int alignment, uint8_t value, FILE* f) {
int cur = ftell(f);
int padding = (alignment - (cur % alignment)) % alignment;
for (int i = 0; i < padding; i++) {
fwrite(&value, 1, 1, f);
}
};
auto varnames = scope_->var_names();
std::unordered_map<std::string, int> varindex;
for (int i = 0; i < varnames.size(); i++) {
varindex[(std::string)varnames[i]] = i;
}
FILE* f = fopen(filename.c_str(), "w+");
fwrite("CINN", 4, 1, f);
int major_v = 0;
int minor_v = 0;
fwrite(&major_v, 4, 1, f);
fwrite(&minor_v, 4, 1, f);
int unused_v = 0;
fwrite(&unused_v, 4, 1, f);
// varname list
int varnamesec = writeplaceholder(4, 1, f);
int namesnum = varnames.size();
fwrite(&namesnum, 4, 1, f);
int nameoffset = writeplaceholder(4, namesnum, f);
for (int i = 0; i < namesnum; i++) {
int namelen = varnames[i].size();
fwrite(&namelen, 4, 1, f);
tellplaceholder(nameoffset + i * 4, f);
fwrite(varnames[i].data(), namelen, 1, f);
fwrite("\0", 1, 1, f);
}
padding(16, 0, f);
tellplaceholder(varnamesec, f);
// pod_values
int buffersec = writeplaceholder(4, 1, f);
int bufoffset = writeplaceholder(4, 1, f);
padding(alignof(cinn_buffer_t), 0, f);
tellplaceholder(bufoffset, f);
std::vector<std::pair<cinn_buffer_t*, int>> pvars;
for (auto& varname : varnames) {
std::string name = (std::string)varname;
auto t = scope_->GetTensor(name);
cinn_buffer_t buffer = *t->buffer();
buffer.memory = reinterpret_cast<uint8_t*>(0);
if (std::find(persistent_vars.begin(), persistent_vars.end(), name) !=
persistent_vars.end()) {
pvars.emplace_back(t->buffer(),
ftell(f) + offsetof(cinn_buffer_t, memory));
}
fwrite(&buffer, sizeof(cinn_buffer_t), 1, f);
}
padding(16, 0, f);
tellplaceholder(buffersec, f);
// persistent_buffers
int pbuffer = writeplaceholder(4, 1, f);
for (auto& p : pvars) {
if (p.first->align) {
padding(p.first->align, 0, f);
}
tellplaceholder(p.second, f);
fwrite(p.first->memory, p.first->memory_size, 1, f);
}
padding(16, 0, f);
tellplaceholder(pbuffer, f);
// instructions
int instsec = writeplaceholder(4, 1, f);
int insnum = 0;
for (auto& ins : instrs_) {
ins->Run(nullptr, true);
insnum += ins->GetFnNames().size();
}
fwrite(&insnum, 4, 1, f);
int instplaceholder = writeplaceholder(4 * 3, insnum, f);
int findex = 0;
for (auto& ins : instrs_) {
auto in_args = ins->GetInArgs();
auto out_args = ins->GetOutArgs();
auto fn_names = ins->GetFnNames();
for (int i = 0; i < fn_names.size(); i++, findex++) {
std::vector<std::string> all_args(in_args[i].begin(), in_args[i].end());
all_args.insert(
std::end(all_args), out_args[i].begin(), out_args[i].end());
auto fname = fn_names[i];
int fnamesize = fname.size();
fwrite(&fnamesize, 4, 1, f);
tellplaceholder(instplaceholder + findex * 12, f);
fwrite(fname.c_str(), fname.size(), 1, f);
fwrite("\0", 1, 1, f);
int argsize = all_args.size();
setplaceholder(instplaceholder + findex * 12 + 4, &argsize, 4, 1, f);
padding(alignof(cinn_pod_value_t), 0, f);
tellplaceholder(instplaceholder + findex * 12 + 8, f);
for (auto& arg : all_args) {
uintptr_t bufindex = varindex[arg];
cinn_pod_value_t v(reinterpret_cast<cinn_buffer_t*>(bufindex));
fwrite(&v, sizeof(cinn_pod_value_t), 1, f);
}
}
}
padding(16, 0, f);
tellplaceholder(instsec, f);
fclose(f);
}
// Runs every runtime instruction in order.
//
// @param name2podargs Optional name-to-argument map forwarded to each
//                     instruction.
// @param stream Forwarded to each instruction. In CUDA builds, a null stream
//               on an NVGPU target makes this call end with
//               cudaDeviceSynchronize().
// @param use_cache Forwarded to each instruction.
void Program::Execute(
    const std::map<std::string, cinn_pod_value_t>* name2podargs,
    void* stream,
    bool use_cache) {
  for (auto& ins : instrs_) {
    ins->Run(name2podargs, false, stream, use_cache);
  }
#ifdef CINN_WITH_CUDA
  VLOG(4) << "-- The value of the used stream: " << stream;
  // An empty program has no instrs_[0]; check before reading its target.
  if (!instrs_.empty() && instrs_[0]->target_.arch == Target::Arch::NVGPU &&
      stream == nullptr) {
    CUDA_CALL(cudaDeviceSynchronize());
  }
#endif
}
// Benchmarks the program: performs 100 untimed warm-up passes over all
// runtime instructions, then times `repeat_` passes and logs the average
// time per pass (reported as ms in the log line).
//
// @param repeat_ Number of timed passes; also the divisor of the average, so
//                callers should pass a positive value.
void Program::ExecuteTest(int repeat_) {
  cinn::utils::Timer timer1;
  // Warm-up passes, excluded from the measurement.
  for (int i = 0; i < 100; i++) {
    for (auto& ins : instrs_) {
      ins->Run();
    }
  }
  timer1.Start();
  for (int i = 0; i < repeat_; i++) {
    for (auto& ins : instrs_) {
      ins->Run();
    }
  }
#ifdef CINN_WITH_CUDA
  // An empty program has no instrs_[0]; check before reading its target.
  if (!instrs_.empty() && instrs_[0]->target_.arch == Target::Arch::NVGPU) {
    CUDA_CALL(cudaDeviceSynchronize());
  }
#endif
  double test_op_time = timer1.Stop() / repeat_;
  VLOG(3) << "Repeat times: [" << repeat_ << "], average op time: ["
          << test_op_time << "] ms";
}
std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph);
GraphCompiler::CompileOptions options;
......
......@@ -31,6 +31,7 @@
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/lang/packed_func.h"
......@@ -40,56 +41,6 @@ namespace cinn {
namespace hlir {
namespace framework {
/**
 * The Program is the runtime instance for running a computation.
 */
class Program {
 public:
  /**
   * Constructor.
   * @param scope The scope containing all the runtime variables.
   * @param instrs The instructions belonging to this program. Ownership of
   * each instruction is taken; those flagged `pre_run` are stored apart from
   * the runtime instructions.
   */
  Program(const std::shared_ptr<Scope>& scope,
          std::vector<std::unique_ptr<Instruction>>&& instrs);

  /**
   * Run the pre-run instructions, then invoke PreRun on the runtime
   * instructions that require it.
   * @param name2podargs Optional name-to-argument map forwarded to the
   * instructions.
   */
  void PreRun(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr);

  /**
   * Serialize the program to a binary file.
   * @param persistent_vars Names of the variables whose memory content is
   * written into the file.
   * @param filename Path of the file to write.
   */
  void Export(const std::vector<std::string>& persistent_vars,
              const std::string& filename);

  /**
   * Execute the program -- that is running all the instructions inside it.
   */
  void Execute(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr,
      void* stream = nullptr,
      bool use_cache = true);

  /**
   * Run the runtime instructions `repeat_` times after a warm-up and log the
   * average time per pass.
   */
  void ExecuteTest(int repeat_);

  /**
   * Get the number of instructions.
   */
  size_t size() const { return instrs_.size(); }

  // Read-only accessors; const-qualified so they are usable on a const
  // Program as well.
  const std::vector<std::unique_ptr<Instruction>>& GetPreRunInstructions()
      const {
    return prerun_instrs_;
  }
  const std::vector<std::unique_ptr<Instruction>>& GetRunInstructions() const {
    return instrs_;
  }

 private:
  // We need to hold scope to assure tensors alive used in instructions.
  std::shared_ptr<Scope> scope_;
  // prerun instructions
  std::vector<std::unique_ptr<Instruction>> prerun_instrs_;
  // only runtime instructions
  std::vector<std::unique_ptr<Instruction>> instrs_;
};
/**
* GraphCompiler compiles a graph and generate the runtime Program.
*/
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
#include <absl/types/variant.h>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/utils/attribute_util.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_type.h"
namespace cinn {
namespace hlir {
namespace framework {
// Translation table from a new-IR operation name (the key) to the name of the
// CINN operator that implements it (the value). Lookups with .at() throw
// std::out_of_range for operations that are not listed here.
const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd.full", "fill_constant"}, {"pd.matmul", "matmul"}};
// TODO(Aurelius84): Need to abstract this logic to implement a Proxy for
// co-existence with GraphCompiler.
std::unique_ptr<Program> NewIRCompiler::Build() {
m_builder_.Clear();
// NOTE(Aurelius84): Currently only support each op for one group
std::vector<std::vector<::ir::Operation*>> groups;
for (auto it = program_.block()->begin(); it != program_.block()->end();
++it) {
groups.push_back({*it});
}
VLOG(4) << "Groups size: " << groups.size();
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
for (int i = 0; i < groups.size(); ++i) {
lowered_funcs.emplace_back(GetOpFunc(*groups[i][0], i));
}
for (auto&& lowered_func : lowered_funcs) {
ProcessFunction(lowered_func);
}
compiler_ = backends::Compiler::Create(target_);
auto build_module = m_builder_.Build();
compiler_->Build(build_module, "");
auto instructions = BuildInstructions(groups);
// TODO(Aurelius84): Instantiate all tensors on compile-time, which is
// controlled by 'options.with_instantiate_variables' in GraphCompiler.
// Moreover, it's better to implement InsertBufferHandlers() logic
// to automatically insert Malloc and Free instructions.
for (auto& name : scope_->var_names()) {
std::string var_name({name.data(), name.size()});
VLOG(4) << "Instantiate " << var_name << " on compile-time";
auto* var = scope_->Var<Tensor>(var_name);
auto& tensor = absl::get<Tensor>(*var);
tensor->mutable_data(target_, tensor->type());
}
return std::make_unique<Program>(scope_, std::move(instructions));
}
// Lowers a single operation into CINN lowered functions.
//
// Placeholders are created for the operands, the CINN strategy registered for
// the mapped operator is selected, its compute and schedule steps are run,
// and the result is lowered under the name produced by GenOpFuncName.
//
// @param op  The operation to lower; its name must be a key of
//            CompatibleInfo::OP_NAMES (otherwise .at() throws).
// @param idx Index of the operation, used to build the function name.
// @return The lowered functions generated for `op`.
std::vector<ir::LoweredFunc> NewIRCompiler::GetOpFunc(const ::ir::Operation& op,
                                                      int idx) {
  auto op_name = op.name();
  VLOG(4) << "GetOpFunc for op: " << op_name;

  // step 1: Deal with Operands
  std::vector<ir::Tensor> inputs;
  std::vector<common::CINNValue> cinn_inputs;
  for (int i = 0; i < op.num_operands(); ++i) {
    auto operand = op.operand(i);
    // TODO(Aurelius84): For now, use addr as name but it's not wise.
    std::string placeholder_name =
        CompatibleInfo::kInputPrefix +
        std::to_string(std::hash<::ir::Value>()(operand));
    auto dense_type =
        operand.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto operand_shape = phi::vectorize<int>(dense_type.dims());
    auto operand_dtype = dense_type.dtype();
    ir::Tensor placeholder = lang::CreatePlaceHolder(
        operand_shape, utils::ConvertIRType(operand_dtype), placeholder_name);
    inputs.push_back(placeholder);
    cinn_inputs.push_back(common::CINNValue(placeholder));
  }
  // The output names are passed to the compute function after the tensors.
  for (auto output_name : OpGetOutputNames(op)) {
    cinn_inputs.push_back(common::CINNValue(output_name));
  }
  VLOG(4) << "inputs.size(): " << inputs.size();

  // step 2: Deal with OpResult
  std::vector<Type> out_types;
  std::vector<std::vector<int>> out_shapes;
  for (int i = 0; i < op.num_results(); ++i) {
    auto result = op.result(i);
    auto dense_type =
        result.type().dyn_cast<paddle::dialect::DenseTensorType>();
    out_types.push_back(utils::ConvertIRType(dense_type.dtype()));
    out_shapes.push_back(phi::vectorize<int>(dense_type.dims()));
  }
  VLOG(4) << "out_types.size(): " << out_types.size();

  NodeAttr node_attrs;
  VLOG(4) << "op.attributes():" << op.attributes().size();
  node_attrs.attr_store = utils::ConvertAttributes(op.attributes());
  const std::string& cinn_op_name = CompatibleInfo::OP_NAMES.at(op_name);
  node_attrs.node_name = cinn_op_name;

  auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
  // NOTE(Aurelius84): Do we need replace all hlir::framework Operator with
  // ::ir::Program ?
  const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name);
  auto impl = OpStrategy::SelectImpl(
      strategy[cinn_op](node_attrs, inputs, out_types, out_shapes, target_));

  // The last element of the pack is the stage map; the others are tensors.
  common::CINNValuePack pack =
      impl->fcompute(common::CINNValuePack{cinn_inputs});
  poly::StageMap stages = pack.back();
  // make sure all the tensors in the stages before schedule launch.
  for (int i = 0; i < pack->size() - 1; i++) {
    ir::Expr expr = pack[i];
    stages->InsertLazily(expr.as_tensor_ref());
  }

  pack = impl->fschedule(pack);
  for (int i = 0; i < pack->size() - 1; i++) {
    ir::Expr expr = pack[i];
    // checkout whether the tensor is with buffer.
    const bool is_candidate = !expr.as_tensor_ref()->buffer.defined() ||
                              this->target_ != common::DefaultNVGPUTarget();
    if (!is_candidate) continue;
    if (stages[expr.as_tensor_ref()]->inlined()) continue;
    inputs.push_back(expr.as_tensor_ref());
  }

  return lang::LowerVec(
      GenOpFuncName(op, idx), stages, inputs, {}, {}, nullptr, target_);
}
// Registers every lowered function in the module builder. Before that, each
// buffer argument that has no variable in scope_ yet gets a tensor created
// from the argument's constant shape and dtype. A single leading '_' is
// dropped from the argument name before the scope lookup.
void NewIRCompiler::ProcessFunction(
    const std::vector<ir::LoweredFunc>& lowered_funcs) {
  for (auto&& func : lowered_funcs) {
    for (auto&& arg : func->args) {
      std::string var_name = arg.name();
      if (var_name[0] == '_') var_name.erase(0, 1);

      // Only buffer arguments that are missing from the scope need a tensor.
      if (scope_->FindVar(var_name) != nullptr || !arg.is_buffer()) continue;

      auto& tensor = absl::get<Tensor>(*scope_->Var<Tensor>(var_name));
      std::vector<Shape::dim_t> dims;
      for (auto& dim_expr : arg.buffer_arg()->shape) {
        CHECK(dim_expr.is_constant());
        dims.push_back(static_cast<int>(dim_expr.get_constant()));
      }
      tensor->Resize(Shape{dims});
      tensor->set_type(arg.buffer_arg()->dtype);
    }
    m_builder_.AddFunction(func);
  }
}
// Creates one Instruction per group, bound to the compiled function that was
// generated for the group's first operation. Aborts via CHECK when the
// compiler has no symbol for the expected function name.
std::vector<std::unique_ptr<Instruction>> NewIRCompiler::BuildInstructions(
    const std::vector<std::vector<::ir::Operation*>>& groups) {
  std::vector<std::unique_ptr<Instruction>> instructions;
  for (int idx = 0; idx < groups.size(); ++idx) {
    // TODO(Aurelius84): only support single op in groups
    auto& op = *groups[idx][0];
    auto instr_name = op.name();
    auto instr = std::make_unique<Instruction>(target_,
                                               scope_.get(),
                                               OpGetInputNames(op),
                                               OpGetOutputNames(op),
                                               instr_name);

    auto& kernel_name = GenOpFuncName(op, idx);
    auto* kernel = compiler_->Lookup(kernel_name);
    CHECK(kernel);
    instr->SetLoweredFunc(reinterpret_cast<void*>(kernel), kernel_name);

    // As some instruction like reduce, will generate more than one kernel.
    // So try to find the rest kernel, if it exists.
    // SetSubKernels(instr.get(), kernel_name);
    instr->Finalize();
    instructions.push_back(std::move(instr));
  }
  return instructions;
}
// Returns the function name associated with (`op`, `idx`), generating and
// caching it in func_names_ on first use. Repeated calls with the same
// operation name and index return a reference to the same cached string.
//
// The cache key is the operation name without its first three characters,
// followed by "_<idx>". std::string::substr throws std::out_of_range if the
// operation name is shorter than three characters.
const std::string& NewIRCompiler::GenOpFuncName(const ::ir::Operation& op,
                                                int idx) {
  // TODO(Aurelius84): . will raise compiler error in pd.xxx, need more
  // elegant way to generate function name.
  std::string op_name = op.name().substr(3) + "_" + std::to_string(idx);

  // Look the key up first so that a fresh global name is requested only when
  // one is actually stored, instead of being generated and thrown away on
  // every cache hit.
  const auto cached = func_names_.find(op_name);
  if (cached != func_names_.end()) {
    return cached->second;
  }
  std::string func_name = Context::Global().NewName("fn_" + op_name);
  return func_names_.emplace(std::move(op_name), std::move(func_name))
      .first->second;
}
// Returns the variable names of the operands of `op`, in operand order and
// without duplicates. A name is CompatibleInfo::kInputPrefix followed by the
// hash of the operand value.
std::vector<std::string> NewIRCompiler::OpGetInputNames(
    const ::ir::Operation& op) {
  std::vector<std::string> names;
  std::unordered_set<std::string> seen;
  for (int i = 0; i < op.num_operands(); ++i) {
    auto value = op.operand(i);
    std::string name = CompatibleInfo::kInputPrefix +
                       std::to_string(std::hash<::ir::Value>()(value));
    // insert() reports whether the name is new; keep first occurrences only.
    if (seen.insert(name).second) {
      names.push_back(name);
    }
  }
  return names;
}
// Returns one variable name per result of `op`, in result order. A name is
// CompatibleInfo::kOutputPrefix followed by the hash of the result value.
std::vector<std::string> NewIRCompiler::OpGetOutputNames(
    const ::ir::Operation& op) {
  std::vector<std::string> names;
  for (int i = 0; i < op.num_results(); ++i) {
    auto result = op.result(i);
    names.emplace_back(CompatibleInfo::kOutputPrefix +
                       std::to_string(std::hash<::ir::Value>()(result)));
  }
  return names;
}
// Creates a Scope holding one tensor variable for every distinct operand and
// result value found in the operations of `program`. Each tensor receives the
// shape and element type of the value's DenseTensorType. `target` is not used
// by this function.
std::shared_ptr<Scope> BuildScope(const Target& target,
                                  const ::ir::Program& program) {
  std::unordered_set<::ir::Value> visited;
  auto scope = std::make_shared<Scope>();

  // Declares the variable for `value` unless it has been handled already.
  auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
    if (!visited.emplace(value).second) return;

    std::string name =
        name_prefix + std::to_string(std::hash<::ir::Value>()(value));
    auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto& tensor = absl::get<Tensor>(*scope->Var<Tensor>(name));
    // NOTE: can be replaced with phi::vectorized ?
    std::vector<Shape::dim_t> dims;
    for (auto d = 0; d < type_info.dims().size(); ++d) {
      dims.push_back(Shape::dim_t(type_info.dims()[d]));
    }
    tensor->Resize(Shape{dims});
    tensor->set_type(utils::ConvertIRType(type_info.dtype()));
  };

  for (auto op_it = program.block()->begin(); op_it != program.block()->end();
       ++op_it) {
    for (auto i = 0; i < (*op_it)->num_operands(); ++i) {
      create_var(CompatibleInfo::kInputPrefix, (*op_it)->operand(i));
    }
    for (auto i = 0; i < (*op_it)->num_results(); ++i) {
      create_var(CompatibleInfo::kOutputPrefix, (*op_it)->result(i));
    }
  }
  return scope;
}
} // namespace framework
} // namespace hlir
} // namespace cinn
......@@ -13,16 +13,10 @@
// limitations under the License.
#pragma once
#include <absl/types/variant.h>
#include <memory>
#include <unordered_map>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/utils/attribute_util.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
......@@ -39,9 +33,6 @@ struct CompatibleInfo {
static const std::unordered_map<std::string, std::string> OP_NAMES;
};
// Translation table from a new-IR operation name (the key) to the name of the
// CINN operator that implements it (the value). Lookups with .at() throw
// std::out_of_range for operations that are not listed here.
const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd.full", "fill_constant"}, {"pd.matmul", "matmul"}};
// TODO(Aurelius84): Need to abstract this logic to implement a Proxy for
// co-existence with GraphCompiler.
class NewIRCompiler final {
......@@ -53,255 +44,33 @@ class NewIRCompiler final {
m_builder_("NewIR", target),
target_(target),
scope_(scope) {}
std::unique_ptr<Program> Build() {
m_builder_.Clear();
// NOTE(Aurelius84): Currently only support each op for one group
std::vector<std::vector<::ir::Operation*>> groups;
for (auto it = program_.block()->begin(); it != program_.block()->end();
++it) {
groups.push_back({*it});
}
VLOG(4) << "Groups size: " << groups.size();
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
for (int i = 0; i < groups.size(); ++i) {
lowered_funcs.emplace_back(GetOpFunc(*groups[i][0], i));
}
for (auto&& lowered_func : lowered_funcs) {
ProcessFunction(lowered_func);
}
compiler_ = backends::Compiler::Create(target_);
auto build_module = m_builder_.Build();
compiler_->Build(build_module, "");
auto instructions = BuildInstructions(groups);
// TODO(Aurelius84): Instantiate all tensors on compile-time, which is
// controlled by 'options.with_instantiate_variables' in GraphCompiler.
// Moreover, it's better to implement InsertBufferHandlers() logic
// to automatically insert Malloc and Free instructions.
for (auto& name : scope_->var_names()) {
std::string var_name({name.data(), name.size()});
VLOG(4) << "Instantiate " << var_name << " on compile-time";
auto* var = scope_->Var<Tensor>(var_name);
auto& tensor = absl::get<Tensor>(*var);
tensor->mutable_data(target_, tensor->type());
}
return std::make_unique<Program>(scope_, std::move(instructions));
}
// Lowers a single operation into CINN lowered functions.
//
// Placeholders are created for the operands, the CINN strategy registered for
// the mapped operator is selected, its compute and schedule steps are run,
// and the result is lowered under the name produced by GenOpFuncName.
//
// @param op  The operation to lower; its name must be a key of
//            CompatibleInfo::OP_NAMES (otherwise .at() throws).
// @param idx Index of the operation, used to build the function name.
// @return The lowered functions generated for `op`.
std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx) {
  auto op_name = op.name();
  VLOG(4) << "GetOpFunc for op: " << op_name;

  // step 1: Deal with Operands
  std::vector<ir::Tensor> inputs;
  std::vector<common::CINNValue> cinn_inputs;
  for (int i = 0; i < op.num_operands(); ++i) {
    auto operand = op.operand(i);
    // TODO(Aurelius84): For now, use addr as name but it's not wise.
    std::string placeholder_name =
        CompatibleInfo::kInputPrefix +
        std::to_string(std::hash<::ir::Value>()(operand));
    auto dense_type =
        operand.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto operand_shape = phi::vectorize<int>(dense_type.dims());
    auto operand_dtype = dense_type.dtype();
    ir::Tensor placeholder = lang::CreatePlaceHolder(
        operand_shape, utils::ConvertIRType(operand_dtype), placeholder_name);
    inputs.push_back(placeholder);
    cinn_inputs.push_back(common::CINNValue(placeholder));
  }
  // The output names are passed to the compute function after the tensors.
  for (auto output_name : OpGetOutputNames(op)) {
    cinn_inputs.push_back(common::CINNValue(output_name));
  }
  VLOG(4) << "inputs.size(): " << inputs.size();

  // step 2: Deal with OpResult
  std::vector<Type> out_types;
  std::vector<std::vector<int>> out_shapes;
  for (int i = 0; i < op.num_results(); ++i) {
    auto result = op.result(i);
    auto dense_type =
        result.type().dyn_cast<paddle::dialect::DenseTensorType>();
    out_types.push_back(utils::ConvertIRType(dense_type.dtype()));
    out_shapes.push_back(phi::vectorize<int>(dense_type.dims()));
  }
  VLOG(4) << "out_types.size(): " << out_types.size();

  NodeAttr node_attrs;
  VLOG(4) << "op.attributes():" << op.attributes().size();
  node_attrs.attr_store = utils::ConvertAttributes(op.attributes());
  const std::string& cinn_op_name = CompatibleInfo::OP_NAMES.at(op_name);
  node_attrs.node_name = cinn_op_name;

  auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
  // NOTE(Aurelius84): Do we need replace all hlir::framework Operator with
  // ::ir::Program ?
  const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name);
  auto impl = OpStrategy::SelectImpl(
      strategy[cinn_op](node_attrs, inputs, out_types, out_shapes, target_));

  // The last element of the pack is the stage map; the others are tensors.
  common::CINNValuePack pack =
      impl->fcompute(common::CINNValuePack{cinn_inputs});
  poly::StageMap stages = pack.back();
  // make sure all the tensors in the stages before schedule launch.
  for (int i = 0; i < pack->size() - 1; i++) {
    ir::Expr expr = pack[i];
    stages->InsertLazily(expr.as_tensor_ref());
  }

  pack = impl->fschedule(pack);
  for (int i = 0; i < pack->size() - 1; i++) {
    ir::Expr expr = pack[i];
    // checkout whether the tensor is with buffer.
    const bool is_candidate = !expr.as_tensor_ref()->buffer.defined() ||
                              this->target_ != common::DefaultNVGPUTarget();
    if (!is_candidate) continue;
    if (stages[expr.as_tensor_ref()]->inlined()) continue;
    inputs.push_back(expr.as_tensor_ref());
  }

  return lang::LowerVec(
      GenOpFuncName(op, idx), stages, inputs, {}, {}, nullptr, target_);
}
// Registers every lowered function in the module builder. Before that, each
// buffer argument that has no variable in scope_ yet gets a tensor created
// from the argument's constant shape and dtype. A single leading '_' is
// dropped from the argument name before the scope lookup.
void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs) {
  for (auto&& func : lowered_funcs) {
    for (auto&& arg : func->args) {
      std::string var_name = arg.name();
      if (var_name[0] == '_') var_name.erase(0, 1);

      // Only buffer arguments that are missing from the scope need a tensor.
      if (scope_->FindVar(var_name) != nullptr || !arg.is_buffer()) continue;

      auto& tensor = absl::get<Tensor>(*scope_->Var<Tensor>(var_name));
      std::vector<Shape::dim_t> dims;
      for (auto& dim_expr : arg.buffer_arg()->shape) {
        CHECK(dim_expr.is_constant());
        dims.push_back(static_cast<int>(dim_expr.get_constant()));
      }
      tensor->Resize(Shape{dims});
      tensor->set_type(arg.buffer_arg()->dtype);
    }
    m_builder_.AddFunction(func);
  }
}
std::unique_ptr<Program> Build();
std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx);
void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);
// Creates one Instruction per group, bound to the compiled function that was
// generated for the group's first operation. Aborts via CHECK when the
// compiler has no symbol for the expected function name.
std::vector<std::unique_ptr<Instruction>> BuildInstructions(
    const std::vector<std::vector<::ir::Operation*>>& groups) {
  std::vector<std::unique_ptr<Instruction>> instructions;
  for (int idx = 0; idx < groups.size(); ++idx) {
    // TODO(Aurelius84): only support single op in groups
    auto& op = *groups[idx][0];
    auto instr_name = op.name();
    auto instr = std::make_unique<Instruction>(target_,
                                               scope_.get(),
                                               OpGetInputNames(op),
                                               OpGetOutputNames(op),
                                               instr_name);

    auto& kernel_name = GenOpFuncName(op, idx);
    auto* kernel = compiler_->Lookup(kernel_name);
    CHECK(kernel);
    instr->SetLoweredFunc(reinterpret_cast<void*>(kernel), kernel_name);

    // As some instruction like reduce, will generate more than one kernel.
    // So try to find the rest kernel, if it exists.
    // SetSubKernels(instr.get(), kernel_name);
    instr->Finalize();
    instructions.push_back(std::move(instr));
  }
  return instructions;
}
const std::vector<std::vector<::ir::Operation*>>& groups);
protected:
// Returns the function name associated with (`op`, `idx`), generating and
// caching it in func_names_ on first use. Repeated calls with the same
// operation name and index return a reference to the same cached string.
//
// The cache key is the operation name without its first three characters,
// followed by "_<idx>". std::string::substr throws std::out_of_range if the
// operation name is shorter than three characters.
const std::string& GenOpFuncName(const ::ir::Operation& op, int idx) {
  // TODO(Aurelius84): . will raise compiler error in pd.xxx, need more
  // elegant way to generate function name.
  std::string op_name = op.name().substr(3) + "_" + std::to_string(idx);

  // Look the key up first so that a fresh global name is requested only when
  // one is actually stored, instead of being generated and thrown away on
  // every cache hit.
  const auto cached = func_names_.find(op_name);
  if (cached != func_names_.end()) {
    return cached->second;
  }
  std::string func_name = Context::Global().NewName("fn_" + op_name);
  return func_names_.emplace(std::move(op_name), std::move(func_name))
      .first->second;
}
const std::string& GenOpFuncName(const ::ir::Operation& op, int idx);
// Returns the variable names of the operands of `op`, in operand order and
// without duplicates. A name is CompatibleInfo::kInputPrefix followed by the
// hash of the operand value.
std::vector<std::string> OpGetInputNames(const ::ir::Operation& op) {
  std::vector<std::string> names;
  std::unordered_set<std::string> seen;
  for (int i = 0; i < op.num_operands(); ++i) {
    auto value = op.operand(i);
    std::string name = CompatibleInfo::kInputPrefix +
                       std::to_string(std::hash<::ir::Value>()(value));
    // insert() reports whether the name is new; keep first occurrences only.
    if (seen.insert(name).second) {
      names.push_back(name);
    }
  }
  return names;
}
std::vector<std::string> OpGetInputNames(const ::ir::Operation& op);
// Returns one variable name per result of `op`, in result order. A name is
// CompatibleInfo::kOutputPrefix followed by the hash of the result value.
std::vector<std::string> OpGetOutputNames(const ::ir::Operation& op) {
  std::vector<std::string> names;
  for (int i = 0; i < op.num_results(); ++i) {
    auto result = op.result(i);
    names.emplace_back(CompatibleInfo::kOutputPrefix +
                       std::to_string(std::hash<::ir::Value>()(result)));
  }
  return names;
}
std::vector<std::string> OpGetOutputNames(const ::ir::Operation& op);
private:
CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler);
const ::ir::Program& program_;
ir::Module::Builder m_builder_;
std::unique_ptr<backends::Compiler> compiler_;
std::unique_ptr<backends::Compiler> compiler_{nullptr};
Target target_;
std::shared_ptr<Scope> scope_;
std::unordered_map<std::string, std::string> func_names_;
};
// Creates a Scope holding one tensor variable for every distinct operand and
// result value found in the operations of `program`. Each tensor receives the
// shape and element type of the value's DenseTensorType. `target` is not used
// by this function.
std::shared_ptr<Scope> BuildScope(const Target& target,
                                  const ::ir::Program& program) {
  std::unordered_set<::ir::Value> visited;
  auto scope = std::make_shared<Scope>();

  // Declares the variable for `value` unless it has been handled already.
  auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
    if (!visited.emplace(value).second) return;

    std::string name =
        name_prefix + std::to_string(std::hash<::ir::Value>()(value));
    auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto& tensor = absl::get<Tensor>(*scope->Var<Tensor>(name));
    // NOTE: can be replaced with phi::vectorized ?
    std::vector<Shape::dim_t> dims;
    for (auto d = 0; d < type_info.dims().size(); ++d) {
      dims.push_back(Shape::dim_t(type_info.dims()[d]));
    }
    tensor->Resize(Shape{dims});
    tensor->set_type(utils::ConvertIRType(type_info.dtype()));
  };

  for (auto op_it = program.block()->begin(); op_it != program.block()->end();
       ++op_it) {
    for (auto i = 0; i < (*op_it)->num_operands(); ++i) {
      create_var(CompatibleInfo::kInputPrefix, (*op_it)->operand(i));
    }
    for (auto i = 0; i < (*op_it)->num_results(); ++i) {
      create_var(CompatibleInfo::kOutputPrefix, (*op_it)->result(i));
    }
  }
  return scope;
}
std::shared_ptr<Scope> BuildScope(const Target&, const ::ir::Program&);
} // namespace framework
} // namespace hlir
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/program.h"
namespace cinn {
namespace hlir {
namespace framework {
// Builds a program that keeps `scope` alive and takes ownership of every
// instruction in `instrs`: those flagged `pre_run` go to prerun_instrs_, all
// others go to instrs_. Relative order is preserved within each list.
Program::Program(const std::shared_ptr<Scope>& scope,
                 std::vector<std::unique_ptr<Instruction>>&& instrs)
    : scope_(scope) {
  for (auto& instruction : instrs) {
    auto& destination = instruction->pre_run ? prerun_instrs_ : instrs_;
    destination.push_back(std::move(instruction));
  }
}
// Runs every pre-run instruction, then calls PreRun() on each runtime
// instruction whose size() equals 4.
// NOTE(review): the meaning of size() == 4 is not visible from here; confirm
// against Instruction before changing it.
void Program::PreRun(
    const std::map<std::string, cinn_pod_value_t>* name2podargs) {
  for (auto& prerun : prerun_instrs_) {
    prerun->Run(name2podargs);
  }
  for (auto& runtime : instrs_) {
    if (runtime->size() != 4) continue;
    runtime->PreRun(name2podargs);
  }
}
void Program::Export(const std::vector<std::string>& persistent_vars,
const std::string& filename) {
auto writeplaceholder = [=](int s, int n, FILE* f) -> int {
int pos = ftell(f);
for (int i = 0; i < s * n; i++) {
fwrite("\0", 1, 1, f);
}
return pos;
};
auto setplaceholder = [=](int p, void* b, int s, int n, FILE* f) {
int cur = ftell(f);
fseek(f, p, SEEK_SET);
fwrite(b, s, n, f);
fseek(f, cur, SEEK_SET);
};
auto tellplaceholder = [=](int p, FILE* f) {
int cur = ftell(f);
setplaceholder(p, &cur, 4, 1, f);
};
auto padding = [=](int alignment, uint8_t value, FILE* f) {
int cur = ftell(f);
int padding = (alignment - (cur % alignment)) % alignment;
for (int i = 0; i < padding; i++) {
fwrite(&value, 1, 1, f);
}
};
auto varnames = scope_->var_names();
std::unordered_map<std::string, int> varindex;
for (int i = 0; i < varnames.size(); i++) {
varindex[(std::string)varnames[i]] = i;
}
FILE* f = fopen(filename.c_str(), "w+");
fwrite("CINN", 4, 1, f);
int major_v = 0;
int minor_v = 0;
fwrite(&major_v, 4, 1, f);
fwrite(&minor_v, 4, 1, f);
int unused_v = 0;
fwrite(&unused_v, 4, 1, f);
// varname list
int varnamesec = writeplaceholder(4, 1, f);
int namesnum = varnames.size();
fwrite(&namesnum, 4, 1, f);
int nameoffset = writeplaceholder(4, namesnum, f);
for (int i = 0; i < namesnum; i++) {
int namelen = varnames[i].size();
fwrite(&namelen, 4, 1, f);
tellplaceholder(nameoffset + i * 4, f);
fwrite(varnames[i].data(), namelen, 1, f);
fwrite("\0", 1, 1, f);
}
padding(16, 0, f);
tellplaceholder(varnamesec, f);
// pod_values
int buffersec = writeplaceholder(4, 1, f);
int bufoffset = writeplaceholder(4, 1, f);
padding(alignof(cinn_buffer_t), 0, f);
tellplaceholder(bufoffset, f);
std::vector<std::pair<cinn_buffer_t*, int>> pvars;
for (auto& varname : varnames) {
std::string name = (std::string)varname;
auto t = scope_->GetTensor(name);
cinn_buffer_t buffer = *t->buffer();
buffer.memory = reinterpret_cast<uint8_t*>(0);
if (std::find(persistent_vars.begin(), persistent_vars.end(), name) !=
persistent_vars.end()) {
pvars.emplace_back(t->buffer(),
ftell(f) + offsetof(cinn_buffer_t, memory));
}
fwrite(&buffer, sizeof(cinn_buffer_t), 1, f);
}
padding(16, 0, f);
tellplaceholder(buffersec, f);
// persistent_buffers
int pbuffer = writeplaceholder(4, 1, f);
for (auto& p : pvars) {
if (p.first->align) {
padding(p.first->align, 0, f);
}
tellplaceholder(p.second, f);
fwrite(p.first->memory, p.first->memory_size, 1, f);
}
padding(16, 0, f);
tellplaceholder(pbuffer, f);
// instructions
int instsec = writeplaceholder(4, 1, f);
int insnum = 0;
for (auto& ins : instrs_) {
ins->Run(nullptr, true);
insnum += ins->GetFnNames().size();
}
fwrite(&insnum, 4, 1, f);
int instplaceholder = writeplaceholder(4 * 3, insnum, f);
int findex = 0;
for (auto& ins : instrs_) {
auto in_args = ins->GetInArgs();
auto out_args = ins->GetOutArgs();
auto fn_names = ins->GetFnNames();
for (int i = 0; i < fn_names.size(); i++, findex++) {
std::vector<std::string> all_args(in_args[i].begin(), in_args[i].end());
all_args.insert(
std::end(all_args), out_args[i].begin(), out_args[i].end());
auto fname = fn_names[i];
int fnamesize = fname.size();
fwrite(&fnamesize, 4, 1, f);
tellplaceholder(instplaceholder + findex * 12, f);
fwrite(fname.c_str(), fname.size(), 1, f);
fwrite("\0", 1, 1, f);
int argsize = all_args.size();
setplaceholder(instplaceholder + findex * 12 + 4, &argsize, 4, 1, f);
padding(alignof(cinn_pod_value_t), 0, f);
tellplaceholder(instplaceholder + findex * 12 + 8, f);
for (auto& arg : all_args) {
uintptr_t bufindex = varindex[arg];
cinn_pod_value_t v(reinterpret_cast<cinn_buffer_t*>(bufindex));
fwrite(&v, sizeof(cinn_pod_value_t), 1, f);
}
}
}
padding(16, 0, f);
tellplaceholder(instsec, f);
fclose(f);
}
// Runs every runtime instruction of this program, in the order they are stored
// in instrs_.
//
// @param name2podargs Forwarded unchanged to each Instruction::Run call.
// @param stream       Forwarded unchanged to each Instruction::Run call. On
//                     CUDA builds, a null stream additionally triggers a
//                     device-wide synchronization after all instructions ran
//                     (only when the program targets NVGPU).
// @param use_cache    Forwarded unchanged to each Instruction::Run call.
void Program::Execute(
    const std::map<std::string, cinn_pod_value_t>* name2podargs,
    void* stream,
    bool use_cache) {
  for (auto& ins : instrs_) {
    // The second argument is `false` here while Export() passes `true`;
    // NOTE(review): presumably a dry-run flag -- confirm against
    // Instruction::Run.
    ins->Run(name2podargs, false, stream, use_cache);
  }
#ifdef CINN_WITH_CUDA
  VLOG(4) << "-- The value of the used stream: " << stream;
  // The target is read from the first instruction, so an empty program must be
  // excluded first: indexing instrs_[0] on an empty vector is undefined
  // behaviour.
  if (!instrs_.empty() && stream == nullptr &&
      instrs_.front()->target_.arch == Target::Arch::NVGPU) {
    CUDA_CALL(cudaDeviceSynchronize());
  }
#endif
}
// Benchmarks the program: runs all runtime instructions for a fixed number of
// untimed warm-up rounds, then times `repeat_` further rounds and logs the
// average time per round at VLOG level 3.
//
// @param repeat_ Number of timed rounds. When it is not positive no timed
//                round is executed and the reported average is 0.
void Program::ExecuteTest(int repeat_) {
  // Untimed rounds executed before the measurement starts.
  constexpr int kWarmupRounds = 100;

  cinn::utils::Timer timer1;
  for (int i = 0; i < kWarmupRounds; i++) {
    for (auto& ins : instrs_) {
      ins->Run();
    }
  }
  timer1.Start();
  for (int i = 0; i < repeat_; i++) {
    for (auto& ins : instrs_) {
      ins->Run();
    }
  }
#ifdef CINN_WITH_CUDA
  // The target is read from the first instruction; guard against an empty
  // program, where instrs_[0] would be an out-of-bounds access.
  if (!instrs_.empty() &&
      instrs_.front()->target_.arch == Target::Arch::NVGPU) {
    // Synchronize before the timer is stopped so the device work is included
    // in the measured interval.
    CUDA_CALL(cudaDeviceSynchronize());
  }
#endif
  // Always stop the timer, but only divide when there was at least one timed
  // round, so a non-positive repeat_ cannot cause a division by zero.
  const auto elapsed = timer1.Stop();
  const double test_op_time =
      repeat_ > 0 ? static_cast<double>(elapsed) / repeat_ : 0.0;
  VLOG(3) << "Repeat times: [" << repeat_ << "], average op time: ["
          << test_op_time << "] ms";
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/scope.h"
namespace cinn {
namespace hlir {
namespace framework {
/**
 * The Program is the runtime instance for running a computation.
 *
 * It keeps the scope alive and stores the instructions in two separate lists:
 * the pre-run instructions and the runtime instructions.
 */
class Program {
 public:
  /**
   * Constructor.
   * @param scope The scope containing all the runtime variables.
   * @param instrs The instructions belonging to this program. Ownership of
   * each instruction is moved into the program.
   */
  Program(const std::shared_ptr<Scope>& scope,
          std::vector<std::unique_ptr<Instruction>>&& instrs);

  /**
   * Process the pre-run instructions.
   * @param name2podargs Optional map from a name to a pod value; may be null.
   */
  void PreRun(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr);

  /**
   * Serialize the program into a file.
   * @param persistent_vars Names of the variables whose buffer contents are
   * written into the file.
   * @param filename Path of the file to write.
   */
  void Export(const std::vector<std::string>& persistent_vars,
              const std::string& filename);

  /**
   * Execute the program -- that is running all the instructions inside it.
   * @param name2podargs Optional map from a name to a pod value, forwarded to
   * every instruction; may be null.
   * @param stream Optional stream handle forwarded to every instruction.
   * @param use_cache Forwarded to every instruction.
   */
  void Execute(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr,
      void* stream = nullptr,
      bool use_cache = true);

  /**
   * Run the runtime instructions repeatedly and log the average time.
   * @param repeat_ The number of timed rounds.
   */
  void ExecuteTest(int repeat_);

  /**
   * Get the number of instructions.
   * Only the runtime instructions are counted, not the pre-run ones.
   */
  size_t size() const { return instrs_.size(); }

  /**
   * Get read-only access to the pre-run instructions.
   */
  const std::vector<std::unique_ptr<Instruction>>& GetPreRunInstructions()
      const {
    return prerun_instrs_;
  }

  /**
   * Get read-only access to the runtime instructions.
   */
  const std::vector<std::unique_ptr<Instruction>>& GetRunInstructions() const {
    return instrs_;
  }

 private:
  // We need to hold scope to assure tensors alive used in instructions.
  std::shared_ptr<Scope> scope_;
  // prerun instructions
  std::vector<std::unique_ptr<Instruction>> prerun_instrs_;
  // only runtime instructions
  std::vector<std::unique_ptr<Instruction>> instrs_;
};
} // namespace framework
} // namespace hlir
} // namespace cinn
# Registers the new-IR compiler test; only active when both WITH_TESTING and
# WITH_CINN are set. The registered tests are tagged with the label
# RUN_TYPE=CINN.
# NOTE(review): this snippet lists two test names and two source files, and
# sets properties twice -- it looks like the removed and added lines of a diff
# merged together. Confirm which variant is the intended one.
if(WITH_TESTING AND WITH_CINN)
cc_test_old(
test_graph_compiler_new_ir
test_new_ir_compiler
SRCS
graph_compiler_new_ir_test.cc
new_ir_compiler_test.cc
DEPS
cinncore
pd_dialect
new_ir_compiler
ir
phi
gtest
glog)
set_tests_properties(test_graph_compiler_new_ir PROPERTIES LABELS
"RUN_TYPE=CINN")
set_tests_properties(test_new_ir_compiler PROPERTIES LABELS "RUN_TYPE=CINN")
endif()
......@@ -23,7 +23,6 @@
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册