未验证 提交 0c7fdda9 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Fix dialect lower bug (#56130)

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* refine code

* fix bug
上级 c00320c5
......@@ -1537,7 +1537,12 @@ void NewIRInterpreter::BuildInstruction() {
size_t op_idx = 0;
for (auto& op : *ir_program_->block()) {
VLOG(6) << "Build Instruction for op: " << op_idx;
if (op->dialect()->name() == "pd_kernel") {
if (op->dialect()->name() == "builtin") {
if (interpreter::GetSpecialOpNames().count(op->name())) {
VLOG(6) << "skip process " << op->name();
continue;
}
} else if (op->dialect()->name() == "pd_kernel") {
auto op_name = op->attributes()
.at("op_name")
.dyn_cast<::ir::StrAttribute>()
......@@ -1546,6 +1551,7 @@ void NewIRInterpreter::BuildInstruction() {
VLOG(6) << "skip process " << op_name;
continue;
}
VLOG(6) << "process " << op_name;
if (op_name == "pd.fused_softmax_mask_upper_triangle" ||
op_name == "pd.fused_softmax_mask_upper_triangle_grad") {
......@@ -1571,7 +1577,7 @@ void NewIRInterpreter::BuildInstruction() {
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Now only support pd_kernel dialect."));
"Now only support pd or pd_kernel dialect."));
}
}
}
......
......@@ -364,14 +364,180 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
std::unordered_map<ir::Operation*, ir::Operation*> map_op_pair;
std::unordered_map<ir::Value, ir::OpResult> map_value_pair;
std::string op_name = paddle::dialect::PhiKernelOp::name();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
std::string phi_kernel_op_name = paddle::dialect::PhiKernelOp::name();
ir::OpInfo phi_kernel_op_info = ctx->GetRegisteredOpInfo(phi_kernel_op_name);
for (auto op_item : *block) {
VLOG(6) << "op name " << op_item->name();
if (op_item->name() == "builtin.combine") {
std::vector<phi::Place> out_places;
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair",
i,
op_item->name()));
auto new_in = map_value_pair.at(cur_in);
vec_inputs.push_back(new_in);
if (new_in.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
out_places.push_back(
new_in.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor type for now"));
}
}
}
// Copy op output type
std::vector<ir::Type> op_output_types;
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
auto result_type = op_item->result(i).type();
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<ir::VectorType>()) {
std::vector<ir::Type> vec_inner_types;
auto base_types = result_type.dyn_cast<ir::VectorType>().data();
for (size_t idx = 0; idx < base_types.size(); idx++) {
auto& base_type = base_types[idx];
if (base_type) {
if (base_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_places[idx],
base_type.dyn_cast<dialect::DenseTensorType>());
vec_inner_types.push_back(allocated_dense_tensor_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor in vector type for now"));
}
} else {
// NOTE(phlrain), kernel not support a nullptr in output
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
phi::DDim dims = {};
phi::DataLayout data_layout = phi::DataLayout::NCHW;
phi::LoD lod = {{}};
size_t offset = 0;
auto dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
vec_inner_types.push_back(dense_tensor_dtype);
}
}
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"builtin.combine Result type only support "
"VectorType<DenseTensorType>"));
}
}
}
// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
map_op_pair[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
map_value_pair[op_item->result(i)] = op->result(i);
}
}
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
continue;
}
if (op_item->name() == "builtin.slice") {
phi::Place out_place = place;
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair",
i,
op_item->name()));
auto new_in = map_value_pair.at(cur_in);
vec_inputs.push_back(new_in);
if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
out_place =
vec_types[op_item->attributes()
.at("index")
.dyn_cast<ir::Int32Attribute>()
.data()]
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}
// Copy op output type
std::vector<ir::Type> op_output_types;
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
auto result_type = op_item->result(i).type();
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_place,
result_type.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"builtin.combine Result type only support DenseTensorType"));
}
}
}
// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
map_op_pair[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
map_value_pair[op_item->result(i)] = op->result(i);
}
}
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
continue;
}
// Lower from PaddleDialect to KernelDialect
paddle::dialect::OpYamlInfoInterface op_info_interface =
op_item->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
std::unique_ptr<OpYamlInfoParser> op_info_parser(nullptr);
if (op_info_interface) {
op_info_parser =
......@@ -399,7 +565,6 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
// need update new kernel key layout and data tyep
std::vector<ir::Type> op_output_types;
if (op_item->num_results() > 0) {
auto phi_kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN(
kernel_fn_str, kernel_key);
......@@ -484,7 +649,6 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
// constuct input
std::vector<ir::OpResult> vec_inputs;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
......@@ -563,7 +727,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
}
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_attribute, op_output_types, op_info);
vec_inputs, op_attribute, op_output_types, phi_kernel_op_info);
map_op_pair[op_item] = op;
......@@ -593,8 +757,8 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
phi::TransToPhiPlace(shadow_key.backend()),
op_item->result(0).type().dyn_cast<dialect::DenseTensorType>());
ir::Operation* shadow_op =
ir::Operation::Create({op->result(0)}, attr_map, {out_type}, op_info);
ir::Operation* shadow_op = ir::Operation::Create(
{op->result(0)}, attr_map, {out_type}, phi_kernel_op_info);
map_op_pair[op_item] = shadow_op;
program->block()->push_back(shadow_op);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册