Unverified commit 9358b4bc, authored by chen2016013, committed by GitHub

[IR] Refactor code in pd_op_to_kernel_pass.cc and reset vlog level (#57054)

* fix merge bug

* fix codestyle
Parent 14f00dc5
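In outline, the commit does two things: it hoists the three near-identical inline `builtin.combine` / `builtin.slice` / `builtin.split` branches out of the main lowering loop into a shared `DealWithSpecialBuiltinOps` helper dispatched through a `SpecialOpList` set, and it wraps the program dumps in `VLOG_IS_ON(2)` guards. Below is a minimal standalone sketch of that dispatch pattern; `Op`, `HandleSpecialBuiltinOp`, and `kSpecialOpList` are illustrative stand-ins, not the actual Paddle IR types:

#include <iostream>
#include <string>
#include <unordered_set>

// Illustrative stand-in for an IR operation; the real code uses ir::Operation*.
struct Op {
  std::string name;
};

// Stand-in for DealWithSpecialBuiltinOps: one shared handler replaces three
// near-identical inline branches in the lowering loop.
void HandleSpecialBuiltinOp(const Op& op) {
  std::cout << "lowering special builtin op: " << op.name << "\n";
}

int main() {
  const std::unordered_set<std::string> kSpecialOpList = {
      "builtin.combine", "builtin.slice", "builtin.split"};
  const Op ops[] = {{"builtin.combine"}, {"pd.matmul"}, {"builtin.split"}};
  for (const Op& op : ops) {
    if (kSpecialOpList.count(op.name)) {  // set-membership dispatch
      HandleSpecialBuiltinOp(op);         // was: three separate if-branches
      continue;
    }
    // ... regular kernel-selection path for ordinary ops ...
  }
  return 0;
}

The helper keeps the per-op lowering logic unchanged; only the duplication in the loop body is removed.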
@@ -253,7 +253,8 @@ void HandleForSpecialOp(
                variable_list);
   }

-  if (op_name == "pd.feed") {
+  if (op_name == "pd.feed" || op_name == "pd.data") {
+    VLOG(6) << "Handle for " << op_name;
     auto value = op->result(0);
     VLOG(6) << "link feed output to feed in variable" << inner_scope;

@@ -273,27 +274,6 @@ void HandleForSpecialOp(
                variable_list);
   }
if (op_name == "pd.data") {
VLOG(6) << "Handle for pd.data";
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
auto value = op->result(0);
paddle::framework::Variable* var = inner_scope->FindVar(var_name);
PADDLE_ENFORCE(var,
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", var_name));
AddNewData(value,
var_name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
if (op_name == "builtin.combine") { if (op_name == "builtin.combine") {
auto out_value = op->result(0); auto out_value = op->result(0);
......
@@ -62,6 +62,166 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
     "builtin.get_parameter",
     "pd.shadow_output"};
+const std::unordered_set<std::string> SpecialOpList = {
+    "builtin.combine", "builtin.slice", "builtin.split"};
+
+ir::OpResult GetNewInput(
+    const ir::Value cur_in,
+    const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
+    const int index,
+    const std::string op_name) {
+  PADDLE_ENFORCE_EQ(
+      map_value_pair.count(cur_in),
+      true,
+      phi::errors::PreconditionNotMet(
+          "[%d]'s input of [%s] op MUST be in map pair", index, op_name));
+  auto new_in = map_value_pair.at(cur_in);
+  return new_in;
+}
+
+void DealWithSpecialBuiltinOps(
+    ir::Operation* op_item,
+    ir::Program* program,
+    std::unordered_map<ir::Operation*, ir::Operation*>* map_op_pair,
+    std::unordered_map<ir::Value, ir::OpResult>* map_value_pair,
+    ir::IrContext* ctx) {
+  if (op_item->name() == "builtin.combine") {
+    std::vector<phi::Place> out_places;
+    // Copy op inputs
+    std::vector<ir::OpResult> vec_inputs;
+    std::vector<ir::Type> vec_inner_types;
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
+        vec_inputs.push_back(new_in);
+        vec_inner_types.push_back(new_in.type());
+        if (new_in.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
+          out_places.push_back(
+              new_in.type()
+                  .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
+                  .place());
+        } else if (new_in.type()
+                       .isa<paddle::dialect::AllocatedSelectedRowsType>()) {
+          out_places.push_back(
+              new_in.type()
+                  .dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
+                  .place());
+        } else {
+          PADDLE_THROW(phi::errors::Unimplemented(
+              "only support dense tensor type for now"));
+        }
+      }
+    }
+    // Copy op output type
+    std::vector<ir::Type> op_output_types;
+    ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
+    op_output_types.push_back(t1);
+    // Get op info
+    ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+    // Generate new op
+    ir::Operation* op = ir::Operation::Create(
+        vec_inputs, op_item->attributes(), op_output_types, op_info);
+    program->block()->push_back(op);
+    (*map_op_pair)[op_item] = op;
+    // only deal with single output
+    if (op_item->num_results() > 0) {
+      for (size_t i = 0; i < op_item->num_results(); ++i) {
+        (*map_value_pair)[op_item->result(i)] = op->result(i);
+      }
+    }
+  }
+
+  if (op_item->name() == "builtin.slice") {
+    std::vector<ir::OpResult> vec_inputs;
+    std::vector<ir::Type> op_output_types;
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
+        vec_inputs.push_back(new_in);
+        if (new_in.type().isa<ir::VectorType>()) {
+          auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
+          auto index = op_item->attributes()
+                           .at("index")
+                           .dyn_cast<ir::Int32Attribute>()
+                           .data();
+          op_output_types.push_back(vec_types[index]);
+        } else {
+          PADDLE_THROW(
+              phi::errors::Unimplemented("only support vector type for now"));
+        }
+      }
+    }
+    // Get op info
+    ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+    // Generate new op
+    ir::Operation* op = ir::Operation::Create(
+        vec_inputs, op_item->attributes(), op_output_types, op_info);
+    program->block()->push_back(op);
+    (*map_op_pair)[op_item] = op;
+    // only deal with single output
+    if (op_item->num_results() > 0) {
+      for (size_t i = 0; i < op_item->num_results(); ++i) {
+        (*map_value_pair)[op_item->result(i)] = op->result(i);
+      }
+    }
+  }
+
+  if (op_item->name() == "builtin.split") {
+    std::vector<phi::Place> out_places(op_item->num_results());
+    // Copy op inputs
+    std::vector<ir::OpResult> vec_inputs;
+    std::vector<ir::Type> op_output_types;
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
+        vec_inputs.push_back(new_in);
+        if (new_in.type().isa<ir::VectorType>()) {
+          auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
+          for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
+            op_output_types.push_back(vec_types[idx]);
+          }
+        } else {
+          PADDLE_THROW(
+              phi::errors::Unimplemented("only support vector type for now"));
+        }
+      }
+    }
+    // Get op info
+    ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+    // Generate new op
+    ir::Operation* op = ir::Operation::Create(
+        vec_inputs, op_item->attributes(), op_output_types, op_info);
+    program->block()->push_back(op);
+    (*map_op_pair)[op_item] = op;
+    // only deal with single output
+    if (op_item->num_results() > 0) {
+      for (size_t i = 0; i < op_item->num_results(); ++i) {
+        (*map_value_pair)[op_item->result(i)] = op->result(i);
+      }
+    }
+  }
+  VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
+}
 bool NeedFallBackCpu(const ir::Operation* op,
                      const std::string& kernel_fn_name,
                      const phi::KernelKey& kernel_key) {
@@ -620,6 +780,11 @@ phi::KernelKey GetKernelKey(
 std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
                                                    phi::Place place) {
+  if (VLOG_IS_ON(2)) {
+    std::stringstream ss;
+    prog->Print(ss);
+    VLOG(2) << "Program before lowering to kernel pass: " << ss.str();
+  }
   auto program = std::make_unique<ir::Program>(ir::IrContext::Instance());
   auto block = prog->block();
@@ -647,163 +812,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
       continue;
     }

-    if (op_item->name() == "builtin.combine") {
-      std::vector<phi::Place> out_places;
-      // Copy op inputs
-      std::vector<ir::OpResult> vec_inputs;
-      std::vector<ir::Type> vec_inner_types;
-      if (op_item->num_operands() > 0) {
-        for (size_t i = 0; i < op_item->num_operands(); ++i) {
-          auto cur_in = op_item->operand_source(i);
-          if (!cur_in) {
-            vec_inputs.emplace_back();
-            continue;
-          }
-          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
-                            true,
-                            phi::errors::PreconditionNotMet(
-                                "[%d]'s input of [%s] op MUST in map pair",
-                                i,
-                                op_item->name()));
-          auto new_in = map_value_pair.at(cur_in);
-          vec_inputs.push_back(new_in);
-          vec_inner_types.push_back(new_in.type());
-          if (new_in.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
-            out_places.push_back(
-                new_in.type()
-                    .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
-                    .place());
-          } else if (new_in.type()
-                         .isa<paddle::dialect::AllocatedSelectedRowsType>()) {
-            out_places.push_back(
-                new_in.type()
-                    .dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
-                    .place());
-          } else {
-            PADDLE_THROW(phi::errors::Unimplemented(
-                "only support dense tensor type for now"));
-          }
-        }
-      }
-      // Copy op output type
-      std::vector<ir::Type> op_output_types;
-      ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
-      op_output_types.push_back(t1);
-      // Get op info
-      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
-      // Generate new op
-      ir::Operation* op = ir::Operation::Create(
-          vec_inputs, op_item->attributes(), op_output_types, op_info);
-      program->block()->push_back(op);
-      map_op_pair[op_item] = op;
-      // only deal with single output
-      if (op_item->num_results() > 0) {
-        for (size_t i = 0; i < op_item->num_results(); ++i) {
-          map_value_pair[op_item->result(i)] = op->result(i);
-        }
-      }
-      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
-      continue;
-    }
-
-    if (op_item->name() == "builtin.slice") {
-      std::vector<ir::OpResult> vec_inputs;
-      std::vector<ir::Type> op_output_types;
-      if (op_item->num_operands() > 0) {
-        for (size_t i = 0; i < op_item->num_operands(); ++i) {
-          auto cur_in = op_item->operand_source(i);
-          if (!cur_in) {
-            vec_inputs.emplace_back();
-            continue;
-          }
-          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
-                            true,
-                            phi::errors::PreconditionNotMet(
-                                "[%d]'s input of [%s] op MUST in map pair",
-                                i,
-                                op_item->name()));
-          auto new_in = map_value_pair.at(cur_in);
-          vec_inputs.push_back(new_in);
-          if (new_in.type().isa<ir::VectorType>()) {
-            auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
-            auto index = op_item->attributes()
-                             .at("index")
-                             .dyn_cast<ir::Int32Attribute>()
-                             .data();
-            op_output_types.push_back(vec_types[index]);
-          } else {
-            PADDLE_THROW(
-                phi::errors::Unimplemented("only support vector type for now"));
-          }
-        }
-      }
-      // Get op info
-      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
-      // Generate new op
-      ir::Operation* op = ir::Operation::Create(
-          vec_inputs, op_item->attributes(), op_output_types, op_info);
-      program->block()->push_back(op);
-      map_op_pair[op_item] = op;
-      // only deal with single output
-      if (op_item->num_results() > 0) {
-        for (size_t i = 0; i < op_item->num_results(); ++i) {
-          map_value_pair[op_item->result(i)] = op->result(i);
-        }
-      }
-      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
-      continue;
-    }
-
-    if (op_item->name() == "builtin.split") {
-      std::vector<phi::Place> out_places(op_item->num_results());
-      // Copy op inputs
-      std::vector<ir::OpResult> vec_inputs;
-      std::vector<ir::Type> op_output_types;
-      if (op_item->num_operands() > 0) {
-        for (size_t i = 0; i < op_item->num_operands(); ++i) {
-          auto cur_in = op_item->operand_source(i);
-          if (!cur_in) {
-            vec_inputs.emplace_back();
-            continue;
-          }
-          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
-                            true,
-                            phi::errors::PreconditionNotMet(
-                                "[%d]'s input of [%s] op MUST in map pair",
-                                i,
-                                op_item->name()));
-          auto new_in = map_value_pair.at(cur_in);
-          vec_inputs.push_back(new_in);
-          if (new_in.type().isa<ir::VectorType>()) {
-            auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
-            for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
-              op_output_types.push_back(vec_types[idx]);
-            }
-          } else {
-            PADDLE_THROW(
-                phi::errors::Unimplemented("only support vector type for now"));
-          }
-        }
-      }
-      // Get op info
-      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
-      // Generate new op
-      ir::Operation* op = ir::Operation::Create(
-          vec_inputs, op_item->attributes(), op_output_types, op_info);
-      program->block()->push_back(op);
-      map_op_pair[op_item] = op;
-      // only deal with single output
-      if (op_item->num_results() > 0) {
-        for (size_t i = 0; i < op_item->num_results(); ++i) {
-          map_value_pair[op_item->result(i)] = op->result(i);
-        }
-      }
-      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
+    if (SpecialOpList.count(op_item->name())) {
+      DealWithSpecialBuiltinOps(
+          op_item, program.get(), &map_op_pair, &map_value_pair, ctx);
       continue;
     }
@@ -1167,7 +1178,11 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
       }
     }
   }
+  if (VLOG_IS_ON(2)) {
+    std::stringstream ss1;
+    program->Print(ss1);
+    VLOG(2) << "Program after lowering to kernel pass : " << ss1.str();
+  }
   return program;
 }
......
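Usage note (an assumption based on Paddle's standard glog-style logging controls, not stated in the diff itself): the `VLOG_IS_ON(2)` guard means the program text is only serialized when verbosity level 2 or higher is enabled, for example by running with the environment variable `GLOG_v=2`; at lower verbosity the `Print` call is skipped entirely, so the dumps cost nothing in ordinary runs.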