未验证 提交 b0193f3a 编写于 作者: F Fisher 提交者: GitHub

Clean unused old graph compiler (#55484)

In preparation for the improvement of the graph compiler, the deprecated old graph compiler was cleaned up.
上级 ee65599e
......@@ -31,7 +31,6 @@
#include "paddle/cinn/utils/profiler.h"
DECLARE_bool(cinn_ir_schedule);
DECLARE_int32(cinn_parallel_compile_size);
namespace cinn {
namespace hlir {
......@@ -254,530 +253,6 @@ void Program::ExecuteTest(int repeat_) {
<< test_op_time << "] ms";
}
void GraphCompiler::PrintFunc() {
auto topo_order = graph_->topological_order();
auto& nodes = std::get<0>(topo_order);
auto& edges = std::get<1>(topo_order);
for (auto& n : nodes) {
auto* node = n->safe_as<Node>();
if (node) {
auto lowered_func = GetOpFunc(node);
}
}
}
std::string GraphCompiler::GenSourceCode() {
auto topo_order = graph_->topological_order();
auto& nodes = std::get<0>(topo_order);
auto& edges = std::get<1>(topo_order);
for (auto& n : nodes) {
auto* node = n->safe_as<Node>();
if (node) {
auto lowered_func = GetOpFunc(node);
for (auto& i : lowered_func) {
m_builder_.AddFunction(i);
}
}
}
// // compile the module
if (!compiler_) {
compiler_ = backends::Compiler::Create(target_);
}
auto build_module = m_builder_.Build();
return compiler_->GetSourceCode(build_module);
}
const std::string& GraphCompiler::GetOrGenFullFuncName(
const std::string& prefix) {
// try_emplace only insert once, so the same function
// can get a consistent name next time
prefix2full_namemap_.try_emplace(prefix, Context::Global().NewName(prefix));
return prefix2full_namemap_.at(prefix);
}
std::vector<ir::LoweredFunc> GraphCompiler::GetOpFuncWithIRSchedule(
const Node* node,
const absl::flat_hash_map<std::string, Type>& type_dict_,
const absl::flat_hash_map<std::string, shape_t>& shape_dict_) {
// get input tensor and output tensor
auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
std::vector<ir::Tensor> tensor_inputs;
std::vector<common::CINNValue> cinn_inputs;
std::vector<std::string> input_output_nodes;
VLOG(3) << "GetOpFunc of op " << node->id();
// 1.Collect inputs info and outputs info
for (auto& i : node->inlinks_in_order()) {
std::string id = i->source()->as<NodeData>()->id();
auto shape = shape_dict_.at(id);
Type dtype = type_dict_.at(id);
CHECK(dtype.is_supported())
<< "The dtype of node " << id
<< " is not float or bool or int! Other dtype is not implemented yet.";
ir::Tensor input;
if (dtype.is_float(32)) {
input = lang::Placeholder<float>(id, shape);
} else if (dtype.is_float(64)) {
input = lang::Placeholder<double>(id, shape);
} else if (dtype.is_bfloat16()) {
input = lang::Placeholder<bfloat16>(id, shape);
} else if (dtype.is_float16()) {
input = lang::Placeholder<float16>(id, shape);
} else if (dtype.is_bool()) {
input = lang::Placeholder<bool>(id, shape);
} else if (dtype.is_int(8)) {
input = lang::Placeholder<int8_t>(id, shape);
} else if (dtype.is_int(16)) {
input = lang::Placeholder<int16_t>(id, shape);
} else if (dtype.is_int(32)) {
input = lang::Placeholder<int32_t>(id, shape);
} else if (dtype.is_int(64)) {
input = lang::Placeholder<int64_t>(id, shape);
} else if (dtype.is_uint(8)) {
input = lang::Placeholder<uint8_t>(id, shape);
} else if (dtype.is_uint(16)) {
input = lang::Placeholder<uint16_t>(id, shape);
} else if (dtype.is_uint(32)) {
input = lang::Placeholder<uint32_t>(id, shape);
} else if (dtype.is_uint(64)) {
input = lang::Placeholder<uint64_t>(id, shape);
}
tensor_inputs.push_back(input);
cinn_inputs.push_back(common::CINNValue(input));
input_output_nodes.push_back(id);
}
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
auto node_datas = GetAllNodeData(node);
for (auto node_data : node_datas) {
// collect output node data name.
std::string out_name = node_data->id();
VLOG(3) << "cinn_inputs.push_back " << out_name;
cinn_inputs.push_back(common::CINNValue(out_name));
out_types.push_back(type_dict_.at(out_name));
out_shapes.push_back(shape_dict_.at(out_name));
input_output_nodes.push_back(out_name);
}
auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
node->attrs, tensor_inputs, out_types, out_shapes, target_));
auto res = GetFuncFromImpl(impl,
common::CINNValuePack{cinn_inputs},
tensor_inputs,
input_output_nodes,
node->id(),
target_);
return res;
}
std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const Node* node) {
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>("infershape");
auto& dtype_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, Type>>("inferdtype");
std::vector<ir::Tensor> inputs;
std::vector<common::CINNValue> cinn_inputs;
std::vector<std::vector<int>> output_shapes;
VLOG(3) << "GetOpFunc of op " << node->id();
for (auto& i : node->inlinks_in_order()) {
std::string input_id = i->source()->as<NodeData>()->id();
auto in_shape = shape_dict.at(input_id);
Type dtype = dtype_dict.at(input_id);
CHECK(dtype.is_supported())
<< "The dtype of node " << input_id
<< " is not float or bool or int! Other dtype is not implemented yet.";
ir::Tensor temp;
if (dtype.is_float(32)) {
temp = lang::Placeholder<float>(input_id, in_shape);
} else if (dtype.is_float(64)) {
temp = lang::Placeholder<double>(input_id, in_shape);
} else if (dtype.is_bfloat16()) {
temp = lang::Placeholder<bfloat16>(input_id, in_shape);
} else if (dtype.is_float16()) {
temp = lang::Placeholder<float16>(input_id, in_shape);
} else if (dtype.is_bool()) {
temp = lang::Placeholder<bool>(input_id, in_shape);
} else if (dtype.is_int(8)) {
temp = lang::Placeholder<int8_t>(input_id, in_shape);
} else if (dtype.is_int(16)) {
temp = lang::Placeholder<int16_t>(input_id, in_shape);
} else if (dtype.is_int(32)) {
temp = lang::Placeholder<int32_t>(input_id, in_shape);
} else if (dtype.is_int(64)) {
temp = lang::Placeholder<int64_t>(input_id, in_shape);
} else if (dtype.is_uint(8)) {
temp = lang::Placeholder<uint8_t>(input_id, in_shape);
} else if (dtype.is_uint(16)) {
temp = lang::Placeholder<uint16_t>(input_id, in_shape);
} else if (dtype.is_uint(32)) {
temp = lang::Placeholder<uint32_t>(input_id, in_shape);
} else if (dtype.is_uint(64)) {
temp = lang::Placeholder<uint64_t>(input_id, in_shape);
}
inputs.push_back(temp);
cinn_inputs.push_back(common::CINNValue(temp));
}
std::vector<Type> out_types;
for (auto& out : node->outlinks_in_order()) {
std::string out_id = out->sink()->safe_as<NodeData>()->id();
auto out_shape = shape_dict.at(out_id);
Type dtype = dtype_dict.at(out_id);
output_shapes.push_back(out_shape);
out_types.push_back(dtype);
}
auto impl = OpStrategy::SelectImpl(strategy[node->op()](
node->attrs, inputs, out_types, output_shapes, target_));
common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs});
poly::StageMap stages = C.back();
// make sure all the tensors in the stages before schedule launch.
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
stages->InsertLazily(temp.as_tensor_ref());
}
C = impl->fschedule(C);
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
// checkout whether the tensor is with buffer.
if ((!temp.as_tensor_ref()->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) &&
!stages[temp.as_tensor_ref()]->inlined()) {
inputs.push_back(temp.as_tensor_ref());
}
}
auto func = lang::LowerVec(GetOrGenFullFuncName(GenOpFuncName(node)),
stages,
inputs,
{},
{},
nullptr,
this->target_);
VLOG(3) << "The [" << func.size() << "] functions of node ["
<< node->attrs.node_name << "] are:\n";
for (auto& i : func) {
VLOG(3) << i;
}
return func;
}
// get the most complex op's index in the fused groups according to the
// OpPattern. If the OpPattern is same, we will take the latter.
int GetMasterRefNode(const std::vector<Node*>& nodes) {
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
int master_index = 0;
int master_pattern = op_pattern_dict[nodes[0]->op()];
for (int i = 1; i < nodes.size(); i++) {
int pattern = op_pattern_dict[nodes[i]->op()];
master_index = pattern >= master_pattern ? i : master_index;
master_pattern = std::max(pattern, master_pattern);
}
VLOG(3) << "master_index: " << master_index
<< ", master op: " << nodes[master_index]->op()->name;
return master_index;
}
std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(
const std::vector<Node*>& nodes) {
CHECK_GT(nodes.size(), 1) << "fuse nodes number must be greater than 1";
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>("infershape");
auto& dtype_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, Type>>("inferdtype");
int fuse_number = nodes.size();
VLOG(3) << "fuse begin: " << nodes[0]->id();
std::vector<ir::Tensor> inputs;
std::vector<ir::Tensor> outputs;
poly::StageMap stages;
int index = 0;
std::string fuse_name = "fn_";
std::unordered_set<NodeData*> in_vars;
std::unordered_set<NodeData*> out_vars;
absl::flat_hash_map<NodeData*, Expr> temp_var_map;
absl::flat_hash_set<ir::Tensor> fetch_tensors;
ir::Tensor master_out_tensor;
int master_index = GetMasterRefNode(nodes);
for (auto& node : nodes) {
std::vector<ir::Tensor> temp_inputs;
std::vector<common::CINNValue> cinn_inputs;
std::vector<std::vector<int>> output_shapes;
fuse_name += node->id() + "_";
for (auto& link : node->inlinks_in_order()) {
auto source = link->source();
CHECK(source);
auto source_data = source->as<NodeData>();
CHECK(source_data);
if (temp_var_map.count(source_data)) {
VLOG(3) << "duplicate var: " << source_data->id();
Expr fuse_out = temp_var_map[source_data];
cinn_inputs.push_back(common::CINNValue(fuse_out));
temp_inputs.push_back(fuse_out.as_tensor_ref());
} else {
std::string input_id = source_data->id();
auto in_shape = shape_dict.at(input_id);
Type dtype = dtype_dict.at(input_id);
CHECK(dtype.is_supported()) << "The dtype of node " << input_id
<< " is not float or bool or int! Other "
"dtype is not implemented yet.";
ir::Tensor temp_in;
if (dtype.is_float(32)) {
temp_in = lang::Placeholder<float>(input_id, in_shape);
} else if (dtype.is_float(64)) {
temp_in = lang::Placeholder<double>(input_id, in_shape);
} else if (dtype.is_bfloat16()) {
temp_in = lang::Placeholder<bfloat16>(input_id, in_shape);
} else if (dtype.is_float16()) {
temp_in = lang::Placeholder<float16>(input_id, in_shape);
} else if (dtype.is_bool()) {
temp_in = lang::Placeholder<bool>(input_id, in_shape);
} else if (dtype.is_int(8)) {
temp_in = lang::Placeholder<int8_t>(input_id, in_shape);
} else if (dtype.is_int(16)) {
temp_in = lang::Placeholder<int16_t>(input_id, in_shape);
} else if (dtype.is_int(32)) {
temp_in = lang::Placeholder<int32_t>(input_id, in_shape);
} else if (dtype.is_int(64)) {
temp_in = lang::Placeholder<int64_t>(input_id, in_shape);
} else if (dtype.is_uint(8)) {
temp_in = lang::Placeholder<uint8_t>(input_id, in_shape);
} else if (dtype.is_uint(16)) {
temp_in = lang::Placeholder<uint16_t>(input_id, in_shape);
} else if (dtype.is_uint(32)) {
temp_in = lang::Placeholder<uint32_t>(input_id, in_shape);
} else if (dtype.is_uint(64)) {
temp_in = lang::Placeholder<uint64_t>(input_id, in_shape);
}
inputs.push_back(temp_in);
temp_inputs.push_back(temp_in);
cinn_inputs.push_back(common::CINNValue(temp_in));
temp_var_map[source_data] = Expr(temp_in);
}
in_vars.insert(source_data);
}
std::vector<Type> out_types;
std::vector<NodeData*> temp_outvars;
for (auto& out : node->outlinks_in_order()) {
auto out_var = out->sink()->safe_as<NodeData>();
CHECK(out_var);
out_vars.insert(out_var);
temp_outvars.push_back(out_var);
std::string out_id = out_var->id();
VLOG(3) << "out_id " << out_id;
auto out_shape = shape_dict.at(out_id);
Type dtype = dtype_dict.at(out_id);
output_shapes.push_back(out_shape);
out_types.push_back(dtype);
}
auto impl = OpStrategy::SelectImpl(strategy[node->op()](
node->attrs, temp_inputs, out_types, output_shapes, target_));
common::CINNValuePack C =
impl->fcompute(common::CINNValuePack{cinn_inputs});
if (index == master_index) {
// use the most complex op's schedule as the fused ops' schedule.
C = impl->fschedule(C);
CHECK(!C.empty());
Expr out = C[0];
master_out_tensor = out.as_tensor_ref();
}
CHECK_GE(C.size(), 2);
std::vector<Expr> temp_C;
if (C.size() - 1 > node->outlinks_in_order().size()) {
for (int i = 1; i < C.size() - 1; i++) {
ir::Expr temp = C[i];
VLOG(1) << "C[" << i << "] name is : " << temp.as_tensor_ref()->name;
outputs.push_back(temp.as_tensor_ref());
}
common::CINNValuePack C_temp{{C[0], C.back()}};
C = C_temp;
}
for (int i = 0; i < C.size() - 1; i++) {
Expr out = C[i];
temp_var_map[temp_outvars[i]] = out;
if (fetch_var_ids_.count(temp_outvars[i]->id())) {
VLOG(3) << "get fetch output var " << temp_outvars[i]->id();
CHECK(out.as_tensor());
fetch_tensors.insert(out.as_tensor_ref());
}
}
CHECK_LE(C.size() - 1, node->outlinks_in_order().size());
poly::StageMap temp_stages = C.back();
for (auto& i : temp_stages) {
auto tensor = ir::Tensor(i.second->tensor());
stages->InsertLazily(tensor, i.second.get());
}
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
CHECK(temp.as_tensor());
auto temp_tensor = temp.as_tensor_ref();
stages->InsertLazily(temp_tensor, temp_stages[temp_tensor]);
if (index < fuse_number - 1 && !temp_tensor->is_reduce_tensor()) {
// assume that only the first out_var links to other op node which will
// compute inline
if (fetch_tensors.count(temp_tensor)) {
VLOG(3) << "add op's fetch out_vars: " << temp_tensor->name;
outputs.insert(outputs.begin(), temp_tensor);
} else if (i == 0) {
VLOG(3) << "inline " << temp_tensor->name;
stages[temp_tensor]->ComputeInline();
} else {
VLOG(3) << "add middle op's other out_vars: " << temp_tensor->name;
outputs.push_back(temp_tensor);
}
} else if (index < fuse_number - 1 && temp_tensor->is_reduce_tensor()) {
VLOG(3) << "temp buffer " << temp_tensor->name;
VLOG(3) << "add op's out_vars: " << temp_tensor->name;
outputs.push_back(temp_tensor);
} else {
if (index == fuse_number - 1) {
// final output tensor
outputs.insert(outputs.begin(), temp_tensor);
} else {
outputs.push_back(temp_tensor);
}
}
}
index++;
}
fuse_name += "fused";
VLOG(3) << "fuse_name: " << fuse_name;
// args order: inputs + final output + fetch outputs + other no_fused outputs
for (auto& tensor : outputs) {
// checkout the tensor is with buffer.
if ((!tensor->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) &&
!stages[tensor]->inlined()) {
inputs.push_back(tensor);
}
}
ir::Tensor final_out_tensor = outputs.front();
if (final_out_tensor->name != master_out_tensor->name) {
if (final_out_tensor->is_reduce_tensor()) {
VLOG(3) << "final_out_tensor is reduce tensor!";
} else {
stages[final_out_tensor]->CopyTransform(stages[master_out_tensor]);
stages[final_out_tensor]->CopyLoopInfo(stages[master_out_tensor]);
}
}
for (auto& s : stages) {
auto& compute_ats = s.second->GetComputeAts();
auto tensor = s.second->tensor();
if (!compute_ats.empty()) {
poly::ComputeAtRelation new_relation;
CHECK_EQ(compute_ats.size(), 1U);
auto new_stage = stages[final_out_tensor];
for (auto& compute_at : compute_ats) {
auto& old_relation = compute_at.second;
auto old_target_tensor = old_relation.stage->tensor();
if (stages[old_target_tensor]->inlined()) {
new_relation.stage = new_stage;
new_relation.level = old_relation.level;
compute_ats.clear();
CHECK(new_relation.IsCompatible(s.second.get()))
<< "new computeAt should be compatible";
compute_ats[new_stage->id()] = new_relation;
break;
}
}
}
}
// deal with fetch tensors, not compute_inline but do compute_at
for (auto& fetch_tensor : fetch_tensors) {
if (fetch_tensor->is_reduce_tensor() ||
fetch_tensor->name == final_out_tensor->name)
continue;
stages[fetch_tensor]->DisableComputeInline();
int level = stages[final_out_tensor]->n_out_dims() - 1;
VLOG(3) << "no fuse fetch tensor " << fetch_tensor->name
<< " and recomputeAt in level " << level;
// if the fetch tensor size is 1, the fetch tensor cannot fuse by ComputeAt2
int len = 1;
for (const auto& dim : fetch_tensor->shape) {
len *= dim.as_int32();
}
if (len <= 1) {
continue;
}
stages[fetch_tensor]->ComputeAt2(stages[final_out_tensor], level);
}
auto func = lang::LowerVec(GetOrGenFullFuncName(fuse_name),
stages,
inputs,
{},
{},
nullptr,
this->target_);
VLOG(3) << "The [" << func.size() << "] functions are:\n";
for (auto& i : func) {
VLOG(3) << "Function [" << i->name << "] is:\n";
VLOG(3) << i;
}
return func;
}
void GraphCompiler::ProcessFunction(
const std::vector<ir::LoweredFunc>& lowered_funcs) {
for (auto&& func : lowered_funcs) {
std::vector<std::string> input_args;
std::vector<std::string> output_args;
for (auto&& arg : func->args) {
std::string arg_name = arg.name();
if (arg_name[0] == '_') arg_name = arg_name.substr(1);
if (arg.io == ir::Argument::IO::kOutput)
output_args.push_back(arg_name);
else if (arg.io == ir::Argument::IO::kInput)
input_args.push_back(arg_name);
auto* var = scope_->FindVar(arg_name);
if (!arg.is_buffer()) {
VLOG(3) << "function:" << func->name << "-argument:" << arg_name
<< " type is not buffer, lowered_func:\n"
<< func;
}
if (!var &&
arg.is_buffer()) { // For argument buffer not in scope, create it.
auto* new_var = scope_->Var<Tensor>(arg_name);
auto& tensor = absl::get<Tensor>(*new_var);
std::vector<Shape::dim_t> shape;
for (auto& shape_dim : arg.buffer_arg()->shape) {
CHECK(shape_dim.is_constant());
shape.push_back(static_cast<int>(shape_dim.get_constant()));
}
tensor->Resize(Shape{shape});
tensor->set_type(arg.buffer_arg()->dtype);
VLOG(3) << utils::StringFormat(
"Will create a new variable in scope for argument[%s] in "
"function[%s] with shape[%s],dtype[%s]",
arg_name.c_str(),
func->name.c_str(),
utils::Join(tensor->shape().data(), ","),
common::Type2Str(tensor->type()));
}
}
function2input_args_[func->name] = input_args;
function2output_args_[func->name] = output_args;
m_builder_.AddFunction(func);
}
}
std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph);
GraphCompiler::CompileOptions options;
......@@ -801,616 +276,57 @@ GraphCompiler::CompilationResult GraphCompiler::Build(
std::unordered_set<std::string>&& fetch_var_ids,
void* stream) {
Context::Global().ResetNameId();
if (FLAGS_cinn_parallel_compile_size) {
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_
: fetch_var_ids);
if (options.with_instantiate_variables) {
VLOG(3) << "Instantiate all variables on compile-time";
utils::RecordEvent("GraphCompiler MutableData",
utils::EventType::kOrdinary);
// All variables reside in scope_, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) {
auto* var =
scope_->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var);
if (reuse_vars_map_.count(name)) {
auto src_var_name = reuse_vars_map_.at(name);
auto* src_var = scope_->Var<Tensor>(src_var_name);
auto& src_tensor = absl::get<Tensor>(*src_var);
tensor->set_buffer(src_tensor->get_buffer());
} else {
tensor->mutable_data(target_, tensor->type());
}
}
}
VLOG(2) << "Compile With Parallel Compiler!";
utils::RecordEvent("GraphCompiler CompileResult",
utils::EventType::kOrdinary);
ParallelCompiler::CompileOptions option;
option.lowered_funcs = options.lowered_funcs;
parallel_compiler_ =
std::make_shared<ParallelCompiler>(scope_, graph_, option, target_);
auto instructions = (*parallel_compiler_.get())();
if (options.remove_unused_variables) {
RemoveInvalidVariables(instructions);
}
if (options.with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(&instructions);
}
VLOG(2) << "Compile With Parallel Compiler Done!";
GraphCompiler::CompilationResult compilation_result;
compilation_result.runtime_program.reset(
new Program(scope_, std::move(instructions)));
return compilation_result;
}
compile_options_ = options;
fetch_var_ids_ = std::move(fetch_var_ids);
auto topo_order = graph_->topological_order();
auto& nodes = std::get<0>(topo_order);
VLOG(3) << "Begin GraphCompiler::Build";
function2input_args_.clear();
function2output_args_.clear();
m_builder_.Clear();
// if there are no available groups, we will take each node as a group
if (options.groups.empty() && graph_->groups.empty() &&
graph_->fusion_groups.empty()) {
VLOG(3) << "not run opfusion pass";
for (auto& node : nodes) {
auto op_node = node->safe_as<Node>();
if (op_node) {
graph_->groups.push_back({op_node});
}
}
}
// use the input groups in options firstly if exists
std::vector<std::vector<Node*>> groups;
if (options.groups.empty()) {
groups = graph_->groups;
} else {
for (std::shared_ptr<Graph::Group> g : options.groups) {
groups.push_back(g->CollectNodes());
}
}
// if the input lowered_funcs is empty, we will use the default lowering
// process to generate
std::vector<std::vector<ir::LoweredFunc>> local_lowered_funcs;
if (options.lowered_funcs.empty()) {
utils::RecordEvent("GraphCompiler LoweredFuncs",
utils::EventType::kOrdinary);
// lowering of new fusion pass is not compatible with the groups from the
// input options, thus process it separately
if (!graph_->fusion_groups.empty()) {
auto& dtype_dict =
graph_->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
"inferdtype");
auto& shape_dict =
graph_->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
for (auto& group : graph_->fusion_groups) {
VLOG(3) << "group_id is : " << group->group_id
<< ", and its number is : " << group->nodes.size();
groups.push_back(std::move(group->CollectNodes()));
local_lowered_funcs.emplace_back(std::move(op_lowerer.Lower(group)));
CHECK_EQ(local_lowered_funcs.back().size(), 1)
<< "Lowered Function Is Not Equal 1!";
VLOG(3) << local_lowered_funcs.back()[0];
}
} else {
VLOG(3) << "fusion_groups is empty";
std::vector<ir::LoweredFunc> lowered_func;
if (FLAGS_cinn_ir_schedule) {
auto& dtype_dict =
graph_->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
"inferdtype");
auto& shape_dict =
graph_->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
for (int i = 0; i < groups.size(); i++) {
for (int j = 0; j < groups[i].size(); j++) {
lowered_func =
GetOpFuncWithIRSchedule(groups[i][j], dtype_dict, shape_dict);
local_lowered_funcs.emplace_back(std::move(lowered_func));
}
}
} else {
for (int i = 0; i < groups.size(); i++) {
if (groups[i].size() == 1) {
lowered_func = GetOpFunc(groups[i][0]);
} else {
lowered_func = GetOpFunc(groups[i]);
}
local_lowered_funcs.emplace_back(std::move(lowered_func));
}
}
}
}
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
graph_->VisualizeGroupedGraph(
groups, fetch_var_ids.empty() ? fetch_var_ids_ : fetch_var_ids);
graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_
: fetch_var_ids);
// use the input lowered_funcs in options firstly if exists
const auto& lowered_funcs = options.lowered_funcs.empty()
? local_lowered_funcs
: options.lowered_funcs;
CHECK_EQ(groups.size(), lowered_funcs.size())
<< "The size of groups and lowered_funcs should be equal";
{
utils::RecordEvent("GraphCompiler ProcessFunction",
utils::EventType::kOrdinary);
for (auto&& lowered_func : lowered_funcs) {
this->ProcessFunction(lowered_func);
}
if (options.with_instantiate_variables) {
InstantiateVariables();
}
// compile the module
// Need to create a new compiler for every call of Build,
// because the underneath jit engine doesn't support addIRModule repeatedly
// now.
compiler_ = backends::Compiler::Create(target_);
VLOG(2) << "Compile With Parallel Compiler!";
utils::RecordEvent("GraphCompiler CompileResult",
utils::EventType::kOrdinary);
ParallelCompiler::CompileOptions option;
option.lowered_funcs = options.lowered_funcs;
auto build_module = m_builder_.Build();
VLOG(3) << "End of m_builder_.Build()";
if (this->target_.arch == Target::Arch::X86) {
utils::RecordEvent("GraphCompiler CodeGenCX86",
utils::EventType::kOrdinary);
CodeGenCX86 codegen(this->target_, CodeGenCX86::Feature::AVX512);
codegen.SetInlineBuiltinCodes(false);
auto out = codegen.Compile(build_module, CodeGenC::OutputKind::CImpl);
VLOG(3) << "[X86] C Code is:\n" << out;
}
parallel_compiler_ =
std::make_shared<ParallelCompiler>(scope_, graph_, option, target_);
auto instructions = (*parallel_compiler_.get())();
{
utils::RecordEvent("GraphCompiler BackendsBuild",
utils::EventType::kOrdinary);
compiler_->Build(build_module, options.attached_code);
VLOG(3) << "End of compiler_->Build";
}
auto instructions = BuildInstructions(
groups, options.groups.empty() ? graph_->fusion_groups : options.groups);
VLOG(3) << "End of BuildInstructions";
if (options.remove_unused_variables) {
RemoveInvalidVariables(instructions);
}
if (options.with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(&instructions);
}
VLOG(2) << "Compile With Parallel Compiler Done!";
if (options.with_instantiate_variables) {
VLOG(3) << "Instantiate all variables on compile-time";
utils::RecordEvent("GraphCompiler MutableData",
utils::EventType::kOrdinary);
// All variables reside in scope_, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) {
auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var);
if (reuse_vars_map_.count(name)) {
auto src_var_name = reuse_vars_map_.at(name);
auto* src_var = scope_->Var<Tensor>(src_var_name);
auto& src_tensor = absl::get<Tensor>(*src_var);
tensor->set_buffer(src_tensor->get_buffer());
} else {
tensor->mutable_data(target_, tensor->type());
}
}
}
GraphCompiler::CompilationResult result;
result.runtime_program.reset(new Program(scope_, std::move(instructions)));
return result;
}
void GraphCompiler::SetSubKernels(Instruction* instr,
const std::string& func_name) {
int i = 1;
std::string new_op_func = func_name + "_" + std::to_string(i);
if (function2input_args_.count(new_op_func) != 0) {
CHECK_GT(function2input_args_.count(func_name), 0);
instr->AddInArgs(function2input_args_[func_name]);
instr->AddOutArgs(function2output_args_[func_name]);
}
while (function2input_args_.count(new_op_func) != 0) {
auto* fn_ptr = compiler_->Lookup(new_op_func);
CHECK(fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), new_op_func);
instr->AddInArgs(function2input_args_[new_op_func]);
instr->AddOutArgs(function2output_args_[new_op_func]);
i++;
new_op_func = func_name + "_" + std::to_string(i);
}
GraphCompiler::CompilationResult compilation_result;
compilation_result.runtime_program.reset(
new Program(scope_, std::move(instructions)));
return compilation_result;
}
void GraphCompiler::BuildCublasInstr(const Node& node,
Instruction* instr) const {
instr->ClearInArgs();
instr->AddInArgs(OpGetInputNames(&node));
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>("infershape");
// shape info
std::vector<int> shape_sizes;
for (auto& in_node : node.inlinks_in_order()) {
std::string in_id = in_node->source()->safe_as<NodeData>()->id();
auto in_shape = shape_dict.at(in_id);
instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end());
shape_sizes.push_back(in_shape.size());
}
// cublas_gemm has three input vars, and its output shape is equal to the
// input bias. cublas_matmul only has two input vars, so we should get its
// output shape from shape_dict.
if (node.op()->name == "cublas_matmul") {
for (auto& out_node : node.outlinks_in_order()) {
std::string out_id = out_node->sink()->safe_as<NodeData>()->id();
auto out_shape = shape_dict.at(out_id);
instr->attrs.insert(
instr->attrs.end(), out_shape.begin(), out_shape.end());
shape_sizes.push_back(out_shape.size());
}
}
instr->attrs.insert(
instr->attrs.end(), shape_sizes.begin(), shape_sizes.end());
// attribute info
bool trans_a = false;
if (node.attrs.attr_store.contains("trans_a")) {
trans_a = absl::get<bool>(node.attrs.attr_store.at("trans_a"));
}
instr->attrs.push_back(static_cast<int>(trans_a));
bool trans_b = false;
if (node.attrs.attr_store.contains("trans_b")) {
trans_b = absl::get<bool>(node.attrs.attr_store.at("trans_b"));
}
instr->attrs.push_back(static_cast<int>(trans_b));
bool trans_out = false;
if (node.attrs.attr_store.contains("trans_out")) {
trans_out = absl::get<bool>(node.attrs.attr_store.at("trans_out"));
}
instr->attrs.push_back(static_cast<int>(trans_out));
float alpha = 1.f;
if (node.attrs.attr_store.contains("alpha")) {
alpha = absl::get<float>(node.attrs.attr_store.at("alpha"));
}
instr->attrs.push_back(*reinterpret_cast<int*>(&alpha));
}
std::vector<std::unique_ptr<Instruction>> GraphCompiler::BuildInstructions(
const std::vector<std::vector<Node*>>& groups,
const std::vector<std::shared_ptr<Graph::Group>>& fusion_groups) {
utils::RecordEvent("GraphCompiler BuildInstructions",
utils::EventType::kOrdinary);
std::vector<std::unique_ptr<Instruction>> instructions;
auto topo_order = graph_->topological_order();
auto& nodes = std::get<0>(topo_order);
auto& edges = std::get<1>(topo_order);
VLOG(3) << "Begin GraphCompiler::BuildInstructions";
CHECK_GT(groups.size(), 0);
CHECK_EQ(fusion_groups.size() != 0, groups.size() == fusion_groups.size())
<< "fusion_groups's size must be 0 or equal to groups. Currently "
"fusion_group's size = "
<< fusion_groups.size() << ", group's size = " << groups.size();
for (int idx = 0; idx < groups.size(); ++idx) {
auto& group = groups[idx];
std::shared_ptr<Graph::Group> fusion_group(nullptr);
if (fusion_groups.size()) {
fusion_group = fusion_groups[idx];
}
if (group.size() == 1) {
auto node = group[0];
auto instr_name = node->op()->name;
if (node->op()->name == "reshape" &&
compile_options_.with_instantiate_variables) {
// not run instruction and shares buffer only when instantiate_variables
const auto& inlinks = node->inlinks_in_order();
const auto& outlinks = node->outlinks_in_order();
CHECK_EQ(inlinks.size(), 1U);
CHECK_EQ(outlinks.size(), 1U);
std::string in_id = inlinks[0]->source()->safe_as<NodeData>()->id();
std::string out_id = outlinks[0]->sink()->safe_as<NodeData>()->id();
reuse_vars_map_[out_id] = in_id;
instr_name = "no_run";
}
auto instr = std::unique_ptr<Instruction>(
new Instruction(target_,
scope_.get(),
fusion_group.get() ? fusion_group->input_names
: OpGetInputNames(node),
fusion_group.get() ? fusion_group->output_names
: OpGetOutputNames(node),
instr_name));
if (target_.arch == Target::Arch::NVGPU) {
if (node->op()->name == "conv2d") {
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
for (auto& in_node : node->inlinks_in_order()) {
std::string in_id = in_node->source()->safe_as<NodeData>()->id();
auto in_shape = shape_dict.at(in_id);
instr->attrs.insert(
instr->attrs.end(), in_shape.begin(), in_shape.end());
}
AddAttrs(node->attrs.attr_store,
{"padding", "stride", "dilation"},
instr.get());
if (node->attrs.attr_store.find("groups") !=
node->attrs.attr_store.end()) {
auto conv_groups =
absl::get<int>(node->attrs.attr_store.at("groups"));
instr->attrs.push_back(conv_groups);
} else {
instr->attrs.push_back(1);
}
// output shape
const auto& out_links = node->outlinks_in_order();
CHECK(!out_links.empty());
auto& out_node = out_links.front();
std::string out_id = out_node->sink()->safe_as<NodeData>()->id();
auto out_shape = shape_dict.at(out_id);
instr->attrs.insert(
instr->attrs.end(), out_shape.begin(), out_shape.end());
CHECK_EQ(instr->attrs.size(), 19UL);
// conv type {forward, backward_data, backward_filter}
std::string type = "forward";
if (node->attrs.attr_store.find("conv_type") !=
node->attrs.attr_store.end()) {
type =
absl::get<std::string>(node->attrs.attr_store.at("conv_type"));
}
instr->str_attrs.push_back(type);
if (node->attrs.attr_store.find("data_format") !=
node->attrs.attr_store.end()) {
instr->str_attrs.push_back(
absl::get<std::string>(node->attrs.attr_store["data_format"]));
}
} else if (node->op()->name == "depthwise_conv2d") {
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
for (auto& in_node : node->inlinks_in_order()) {
std::string in_id = in_node->source()->safe_as<NodeData>()->id();
auto in_shape = shape_dict.at(in_id);
instr->attrs.insert(
instr->attrs.end(), in_shape.begin(), in_shape.end());
}
// conv
AddAttrs(node->attrs.attr_store,
{"padding", "stride", "dilation"},
instr.get());
if (node->attrs.attr_store.find("groups") !=
node->attrs.attr_store.end()) {
auto groups = absl::get<int>(node->attrs.attr_store.at("groups"));
instr->attrs.push_back(groups);
} else {
instr->attrs.push_back(instr->attrs[1]);
}
// output shape
const auto& out_links = node->outlinks_in_order();
CHECK(!out_links.empty());
auto& out_node = out_links.front();
std::string out_id = out_node->sink()->safe_as<NodeData>()->id();
auto out_shape = shape_dict.at(out_id);
instr->attrs.insert(
instr->attrs.end(), out_shape.begin(), out_shape.end());
CHECK_EQ(instr->attrs.size(), 19UL);
// conv type {forward, backward_data, backward_filter}
std::string type = "forward";
if (node->attrs.attr_store.find("conv_type") !=
node->attrs.attr_store.end()) {
type =
absl::get<std::string>(node->attrs.attr_store.at("conv_type"));
}
instr->str_attrs.push_back(type);
} else if (node->op()->name == "pool2d") {
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
for (auto& in_node : node->inlinks_in_order()) {
std::string in_id = in_node->source()->safe_as<NodeData>()->id();
auto in_shape = shape_dict.at(in_id);
CHECK_EQ(in_shape.size(), 4UL);
instr->attrs.insert(
instr->attrs.end(), in_shape.begin(), in_shape.end());
}
bool global_pooling = false;
if (node->attrs.attr_store.find("global_pooling") !=
node->attrs.attr_store.end()) {
global_pooling =
absl::get<bool>(node->attrs.attr_store.at("global_pooling"));
}
if (node->attrs.attr_store.find("kernel_size") !=
node->attrs.attr_store.end()) {
if (global_pooling == false) {
auto kernel_size = absl::get<std::vector<int>>(
node->attrs.attr_store.at("kernel_size"));
instr->attrs.insert(
instr->attrs.end(), kernel_size.begin(), kernel_size.end());
} else {
instr->attrs.push_back(instr->attrs[2]);
instr->attrs.push_back(instr->attrs[3]);
}
}
if (node->attrs.attr_store.find("padding_size") !=
node->attrs.attr_store.end()) {
if (global_pooling == false) {
auto padding = absl::get<std::vector<int>>(
node->attrs.attr_store.at("padding_size"));
instr->attrs.insert(
instr->attrs.end(), padding.begin(), padding.end());
if (padding.size() == 2)
instr->attrs.insert(
instr->attrs.end(), padding.begin(), padding.end());
} else {
instr->attrs.push_back(0);
instr->attrs.push_back(0);
instr->attrs.push_back(0);
instr->attrs.push_back(0);
}
}
AddAttrs(node->attrs.attr_store,
{"stride_size", "pool_type"},
instr.get());
for (auto& out_node : node->outlinks_in_order()) {
std::string out_id = out_node->sink()->safe_as<NodeData>()->id();
auto out_shape = shape_dict.at(out_id);
instr->attrs.insert(
instr->attrs.end(), out_shape.begin(), out_shape.end());
}
if (node->attrs.attr_store.find("adaptive") !=
node->attrs.attr_store.end()) {
bool adaptive =
absl::get<bool>(node->attrs.attr_store.at("adaptive"));
if (adaptive)
instr->attrs.push_back(1);
else
instr->attrs.push_back(0);
}
CHECK_EQ(instr->attrs.size(), 17UL);
CHECK_EQ(instr->str_attrs.size(), 1UL);
} else if (node->op()->name == "softmax") {
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
for (auto& in_node : node->inlinks_in_order()) {
std::string in_id = in_node->source()->safe_as<NodeData>()->id();
auto in_shape = shape_dict.at(in_id);
instr->attrs.insert(
instr->attrs.end(), in_shape.begin(), in_shape.end());
}
AddAttrs(node->attrs.attr_store, {"axis"}, instr.get());
} else if (node->op()->name == "mul") {
auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
for (auto& in_node : node->inlinks_in_order()) {
std::string in_id = in_node->source()->safe_as<NodeData>()->id();
auto in_shape = shape_dict.at(in_id);
instr->attrs.insert(
instr->attrs.end(), in_shape.begin(), in_shape.end());
}
if (node->attrs.attr_store.find("x_num_col_dims") !=
node->attrs.attr_store.end()) {
auto axis =
absl::get<int>(node->attrs.attr_store.at("x_num_col_dims"));
instr->attrs.push_back(axis);
} else {
instr->attrs.push_back(1);
}
if (node->attrs.attr_store.find("y_num_col_dims") !=
node->attrs.attr_store.end()) {
auto axis =
absl::get<int>(node->attrs.attr_store.at("y_num_col_dims"));
instr->attrs.push_back(axis);
} else {
instr->attrs.push_back(1);
}
} else if (node->op()->name == "cublas_gemm" ||
node->op()->name == "cublas_matmul") {
BuildCublasInstr(*node, instr.get());
}
}
std::string op_func_name =
fusion_group.get() ? fusion_group->GetFuncName()
: GetOrGenFullFuncName(GenOpFuncName(node));
auto* fn_ptr = compiler_->Lookup(op_func_name);
CHECK(fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), op_func_name);
// As some instruction like reduce, will generate more than one kernel.
// So try to find the rest kernel, if it exists.
SetSubKernels(instr.get(), op_func_name);
if (node->attrs.attr_store.count("pre_run")) {
instr->pre_run = absl::get<bool>(node->attrs.attr_store["pre_run"]);
}
// explicitly call Finalize of the instruction after all assignments on it
// were done
instr->Finalize();
instructions.push_back(std::move(instr));
void GraphCompiler::InstantiateVariables() {
VLOG(3) << "Instantiate all variables on compile-time";
utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary);
// All variables reside in scope_, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) {
auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var);
if (reuse_vars_map_.count(name)) {
auto src_var_name = reuse_vars_map_.at(name);
auto* src_var = scope_->Var<Tensor>(src_var_name);
auto& src_tensor = absl::get<Tensor>(*src_var);
tensor->set_buffer(src_tensor->get_buffer());
} else {
CHECK_GT(group.size(), 1U) << "fuse number should be greater than 1";
std::vector<std::string> inputNames;
std::vector<std::string> outputNames;
std::unordered_set<std::string> names_set;
int count = 0;
std::string fuse_name = "fn_";
if (!fusion_group.get()) {
for (int i = 0; i < group.size(); i++) {
auto node = group[i];
CHECK(node);
fuse_name += node->id() + "_";
auto temp_inputnames = OpGetInputNames(node);
for (int j = 0; j < temp_inputnames.size(); j++) {
if (!names_set.count(temp_inputnames[j])) {
inputNames.push_back(temp_inputnames[j]);
names_set.insert(temp_inputnames[j]);
}
}
auto temp_outputnames = OpGetOutputNames(node);
// fused output arg order: final output, ops no_fused outputs
for (int j = 0; j < temp_outputnames.size(); j++) {
if (!names_set.count(temp_outputnames[j])) {
names_set.insert(temp_outputnames[j]);
// assume that the first out_var of the op node is the fused var
bool is_fetch = fetch_var_ids_.count(temp_outputnames[j]);
if (j == 0 && i != group.size() - 1 && !is_fetch) continue;
if (j == 0 && i == group.size() - 1) {
outputNames.insert(outputNames.begin(), temp_outputnames[0]);
} else if (is_fetch) {
VLOG(3) << "fetch var " << temp_outputnames[j];
outputNames.insert(outputNames.begin(), temp_outputnames[j]);
} else {
outputNames.push_back(temp_outputnames[j]);
}
}
}
}
fuse_name += "fused";
VLOG(3) << "In buildInstructions, fuse_name is : " << fuse_name;
VLOG(3) << "input_names: " << utils::Join(inputNames, ", ");
VLOG(3) << "out_names: " << utils::Join(outputNames, ", ");
}
fuse_name = fusion_group.get() ? fusion_group->GetFuncName()
: GetOrGenFullFuncName(fuse_name);
auto instr = std::unique_ptr<Instruction>(new Instruction(
target_,
scope_.get(),
fusion_group.get() ? fusion_group->input_names : inputNames,
fusion_group.get() ? fusion_group->output_names : outputNames,
fuse_name));
auto* fn_ptr = compiler_->Lookup(fuse_name);
CHECK(fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), fuse_name);
// As some situation like reduce,will generate more than one kernel.
// So try to find the rest kernel, if it exists.
SetSubKernels(instr.get(), fuse_name);
for (int j = 0; j < group.size(); j++) {
auto node = group[j];
if (node->attrs.attr_store.count("pre_run") &&
absl::get<bool>(node->attrs.attr_store["pre_run"]) == true) {
instr->pre_run = true;
}
}
// explicitly call Finalize of the instruction after all assignments on it
// were done
instr->Finalize();
instructions.push_back(std::move(instr));
tensor->mutable_data(target_, tensor->type());
}
}
return instructions;
}
void GraphCompiler::RemoveInvalidVariables(
......@@ -1576,40 +492,6 @@ void GraphCompiler::InsertBufferHandlers(
instructions->swap(results);
}
std::vector<std::string> GraphCompiler::OpGetInputNames(
const Node* node) const {
std::vector<std::string> res;
if (node->op()->name == "cublas_gemm" ||
node->op()->name == "cublas_matmul" || node->op()->name == "conv2d" ||
node->op()->name == "depthwise_conv2d" || node->op()->name == "pool2d" ||
node->op()->name == "softmax" || node->op()->name == "mul" ||
node->op()->name == "matmul") {
for (auto& i : node->inlinks_in_order()) {
res.push_back(i->source()->as<NodeData>()->id());
}
} else {
std::unordered_set<std::string> repeat;
for (auto& inode : node->inlinks_in_order()) {
auto id = inode->source()->as<NodeData>()->id();
if (repeat.count(id)) {
continue;
}
repeat.insert(id);
res.push_back(id);
}
}
return res;
}
std::vector<std::string> GraphCompiler::OpGetOutputNames(
const Node* node) const {
std::vector<std::string> res;
for (auto& i : node->outlinks_in_order()) {
res.push_back(i->sink()->as<NodeData>()->id());
}
return res;
}
std::shared_ptr<Scope> BuildScope(Target target,
const std::shared_ptr<Graph>& graph,
std::shared_ptr<Scope> scope) {
......
......@@ -98,10 +98,7 @@ class GraphCompiler final {
GraphCompiler(Target target,
const std::shared_ptr<Scope>& scope,
const std::shared_ptr<Graph>& graph)
: target_(std::move(target)),
scope_(scope),
graph_(graph),
m_builder_(UniqName("module"), target) {}
: target_(std::move(target)), scope_(scope), graph_(graph) {}
struct CompilationResult {
std::unique_ptr<Program> runtime_program;
......@@ -127,44 +124,15 @@ class GraphCompiler final {
CompilationResult Build(const CompileOptions& options,
std::unordered_set<std::string>&& fetch_var_ids = {},
void* stream = nullptr);
void ExportObject(const std::string& path) { compiler_->ExportObject(path); }
std::unique_ptr<Program> Build(const std::string& code = "");
std::string GenSourceCode();
void PrintFunc();
const std::shared_ptr<Scope>& GetScope() const { return scope_; }
private:
std::vector<ir::LoweredFunc> GetOpFunc(const std::vector<Node*>& nodes);
std::vector<ir::LoweredFunc> GetOpFunc(const Node* node);
// Given a node, lower it to LoweredFunc using new ir schedule
std::vector<ir::LoweredFunc> GetOpFuncWithIRSchedule(
const Node* node,
const absl::flat_hash_map<std::string, Type>& type_dict_,
const absl::flat_hash_map<std::string, shape_t>& shape_dict_);
std::string GenOpFuncName(const Node* node) const {
return "fn_" + node->id();
}
// instantiate all variables on compile time
void InstantiateVariables();
// append a unique number at the end of the function name to distinguish
// different functions from graphs whose structures are same
const std::string& GetOrGenFullFuncName(const std::string& prefix);
// TODO(haozech) add implementation
std::vector<std::string> OpGetInputNames(const Node* node) const;
// TODO(haozech) add implementation
std::vector<std::string> OpGetOutputNames(const Node* node) const;
std::vector<std::unique_ptr<Instruction>> BuildInstructions(
const std::vector<std::vector<Node*>>& groups,
const std::vector<std::shared_ptr<Graph::Group>>& fusion_groups);
void BuildCublasInstr(const Node& node, Instruction* instr) const;
// some variables are eliminated by optimized passes(such as OpFusion),
// we can filter out them according to arguments of the built instructions,
// and erase them from the scope to avoid unnecessary buffer allocation
......@@ -189,28 +157,16 @@ class GraphCompiler final {
// parallel compiler
std::shared_ptr<ParallelCompiler> parallel_compiler_;
void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);
void SetSubKernels(Instruction* instr, const std::string& func_name);
Target target_;
std::shared_ptr<Graph> graph_;
std::shared_ptr<Scope> scope_;
// mapping a function's name to its input artuments' names
std::map<std::string, std::vector<std::string>> function2input_args_;
// mapping a function's name to its output artuments' names
std::map<std::string, std::vector<std::string>> function2output_args_;
// fetch var ids in cinn and the corresponding var nodes will not be fused so
// as to get the result
std::unordered_set<std::string> fetch_var_ids_;
absl::flat_hash_map<std::string, std::string> prefix2full_namemap_;
// map dst reuse var to the src var sharing buffer
absl::flat_hash_map<std::string, std::string> reuse_vars_map_;
std::unique_ptr<backends::Compiler> compiler_;
CompileOptions compile_options_;
ir::Module::Builder m_builder_;
CINN_DISALLOW_COPY_AND_ASSIGN(GraphCompiler);
};
......
......@@ -384,21 +384,7 @@ void BindFrontend(pybind11::module *m) {
program->ExecuteTest(repeat_);
auto out = scope->GetTensor(tensor_out->id);
return out;
})
.def("test_generate_code",
[](Program &self,
const common::Target &target,
const std::vector<Variable> &tensor_inputs,
const std::vector<py::array> &input_data,
const Variable &tensor_out) {
std::shared_ptr<hlir::framework::Graph> g(
new hlir::framework::Graph(self, target));
hlir::framework::ApplyPass(g.get(), "InferShape");
std::shared_ptr<hlir::framework::Scope> scope =
hlir::framework::BuildScope(target, g);
hlir::framework::GraphCompiler gc(target, scope, g);
return gc.GenSourceCode();
});
});
py::class_<frontend::Interpreter>(*m, "Interpreter")
.def(py::init<const std::vector<std::string> &,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册