From 9358b4bc7dd02dcab517816a84bf6e47b0754805 Mon Sep 17 00:00:00 2001
From: chen2016013 <111894720+chen2016013@users.noreply.github.com>
Date: Fri, 8 Sep 2023 10:25:25 +0800
Subject: [PATCH] [IR] Refactor code in pd_op_to_kernel_pass.cc and reset vlog
 level (#57054)

* fix merge bug

* fix codestyle
---
 .../ir/phi_kernel_adaptor/phi_kernel_util.cc  |  24 +-
 .../ir/transforms/pd_op_to_kernel_pass.cc     | 331 +++++++++---------
 2 files changed, 175 insertions(+), 180 deletions(-)

diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index 7e7b0dbe76b..c72641046f5 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -253,7 +253,8 @@ void HandleForSpecialOp(
         variable_list);
   }
 
-  if (op_name == "pd.feed") {
+  if (op_name == "pd.feed" || op_name == "pd.data") {
+    VLOG(6) << "Handle for " << op_name;
     auto value = op->result(0);
     VLOG(6) << "link feed output to feed in variable" << inner_scope;
@@ -273,27 +274,6 @@ void HandleForSpecialOp(
         variable_list);
   }
 
-  if (op_name == "pd.data") {
-    VLOG(6) << "Handle for pd.data";
-    auto var_name =
-        op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
-
-    auto value = op->result(0);
-
-    paddle::framework::Variable* var = inner_scope->FindVar(var_name);
-    PADDLE_ENFORCE(var,
-                   paddle::platform::errors::InvalidArgument(
-                       "The variable %s shoud exist", var_name));
-
-    AddNewData(value,
-               var_name,
-               var,
-               value_2_var_name,
-               variable_2_var_name,
-               var_name_2_id,
-               variable_list);
-  }
-
   if (op_name == "builtin.combine") {
     auto out_value = op->result(0);
diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
index ef954c376e4..618e8997796 100644
--- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
@@ -62,6 +62,166 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
     "builtin.get_parameter",
     "pd.shadow_output"};
 
+const std::unordered_set<std::string> SpecialOpList = {
+    "builtin.combine", "builtin.slice", "builtin.split"};
+
+ir::OpResult GetNewInput(
+    const ir::Value cur_in,
+    const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
+    const int index,
+    const std::string op_name) {
+  PADDLE_ENFORCE_EQ(
+      map_value_pair.count(cur_in),
+      true,
+      phi::errors::PreconditionNotMet(
+          "[%d]'s input of [%s] op MUST be in map pair", index, op_name));
+  auto new_in = map_value_pair.at(cur_in);
+  return new_in;
+}
+
+void DealWithSpecialBuiltinOps(
+    ir::Operation* op_item,
+    ir::Program* program,
+    std::unordered_map<ir::Operation*, ir::Operation*>* map_op_pair,
+    std::unordered_map<ir::Value, ir::OpResult>* map_value_pair,
+    ir::IrContext* ctx) {
+  if (op_item->name() == "builtin.combine") {
+    std::vector<phi::Place> out_places;
+    // Copy op inputs
+    std::vector<ir::OpResult> vec_inputs;
+    std::vector<ir::Type> vec_inner_types;
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
+        vec_inputs.push_back(new_in);
+        vec_inner_types.push_back(new_in.type());
+        if (new_in.type().isa<dialect::AllocatedDenseTensorType>()) {
+          out_places.push_back(
+              new_in.type()
+                  .dyn_cast<dialect::AllocatedDenseTensorType>()
+                  .place());
+        } else if (new_in.type()
+                       .isa<dialect::AllocatedSelectedRowsType>()) {
+          out_places.push_back(
+              new_in.type()
+                  .dyn_cast<dialect::AllocatedSelectedRowsType>()
+                  .place());
+        } else {
+          PADDLE_THROW(phi::errors::Unimplemented(
+              "only support dense tensor type for now"));
+        }
+      }
+    }
+    // Copy op output type
+    std::vector<ir::Type> op_output_types;
+    ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
+    op_output_types.push_back(t1);
+
+    // Get op info
+    ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+    // Generate new op
+    ir::Operation* op = ir::Operation::Create(
+        vec_inputs, op_item->attributes(), op_output_types, op_info);
+    program->block()->push_back(op);
+    (*map_op_pair)[op_item] = op;
+    // only deal with single output
+    if (op_item->num_results() > 0) {
+      for (size_t i = 0; i < op_item->num_results(); ++i) {
+        (*map_value_pair)[op_item->result(i)] = op->result(i);
+      }
+    }
+  }
+
+  if (op_item->name() == "builtin.slice") {
+    std::vector<ir::OpResult> vec_inputs;
+    std::vector<ir::Type> op_output_types;
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
+        vec_inputs.push_back(new_in);
+        if (new_in.type().isa<ir::VectorType>()) {
+          auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
+          auto index = op_item->attributes()
+                           .at("index")
+                           .dyn_cast<ir::Int32Attribute>()
+                           .data();
+          op_output_types.push_back(vec_types[index]);
+        } else {
+          PADDLE_THROW(
+              phi::errors::Unimplemented("only support vector type for now"));
+        }
+      }
+    }
+
+    // Get op info
+    ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+    // Generate new op
+    ir::Operation* op = ir::Operation::Create(
+        vec_inputs, op_item->attributes(), op_output_types, op_info);
+    program->block()->push_back(op);
+    (*map_op_pair)[op_item] = op;
+    // only deal with single output
+    if (op_item->num_results() > 0) {
+      for (size_t i = 0; i < op_item->num_results(); ++i) {
+        (*map_value_pair)[op_item->result(i)] = op->result(i);
+      }
+    }
+  }
+
+  if (op_item->name() == "builtin.split") {
+    std::vector<phi::Place> out_places(op_item->num_results());
+    // Copy op inputs
+    std::vector<ir::OpResult> vec_inputs;
+    std::vector<ir::Type> op_output_types;
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
+        vec_inputs.push_back(new_in);
+
+        if (new_in.type().isa<ir::VectorType>()) {
+          auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
+          for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
+            op_output_types.push_back(vec_types[idx]);
+          }
+        } else {
+          PADDLE_THROW(
+              phi::errors::Unimplemented("only support vector type for now"));
+        }
+      }
+    }
+
+    // Get op info
+    ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+    // Generate new op
+    ir::Operation* op = ir::Operation::Create(
+        vec_inputs, op_item->attributes(), op_output_types, op_info);
+    program->block()->push_back(op);
+    (*map_op_pair)[op_item] = op;
+    // only deal with single output
+    if (op_item->num_results() > 0) {
+      for (size_t i = 0; i < op_item->num_results(); ++i) {
+        (*map_value_pair)[op_item->result(i)] = op->result(i);
+      }
+    }
+  }
+  VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
+}
+
 bool NeedFallBackCpu(const ir::Operation* op,
                      const std::string& kernel_fn_name,
                      const phi::KernelKey& kernel_key) {
@@ -620,6 +780,11 @@ phi::KernelKey GetKernelKey(
 
 std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
                                                    phi::Place place) {
+  if (VLOG_IS_ON(2)) {
+    std::stringstream ss;
+    prog->Print(ss);
+    VLOG(2) << "Program before lowering to kernel pass : " << ss.str();
+  }
   auto program = std::make_unique<ir::Program>(ir::IrContext::Instance());
 
   auto block = prog->block();
@@ -647,163 +812,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
       continue;
     }
 
-    if (op_item->name() == "builtin.combine") {
-      std::vector<phi::Place> out_places;
-      // Copy op inputs
-      std::vector<ir::OpResult> vec_inputs;
-      std::vector<ir::Type> vec_inner_types;
-      if (op_item->num_operands() > 0) {
-        for (size_t i = 0; i < op_item->num_operands(); ++i) {
-          auto cur_in = op_item->operand_source(i);
-          if (!cur_in) {
-            vec_inputs.emplace_back();
-            continue;
-          }
-          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
-                            true,
-                            phi::errors::PreconditionNotMet(
-                                "[%d]'s input of [%s] op MUST in map pair",
-                                i,
-                                op_item->name()));
-          auto new_in = map_value_pair.at(cur_in);
-          vec_inputs.push_back(new_in);
-          vec_inner_types.push_back(new_in.type());
-          if (new_in.type().isa<dialect::AllocatedDenseTensorType>()) {
-            out_places.push_back(
-                new_in.type()
-                    .dyn_cast<dialect::AllocatedDenseTensorType>()
-                    .place());
-          } else if (new_in.type()
-                         .isa<dialect::AllocatedSelectedRowsType>()) {
-            out_places.push_back(
-                new_in.type()
-                    .dyn_cast<dialect::AllocatedSelectedRowsType>()
-                    .place());
-          } else {
-            PADDLE_THROW(phi::errors::Unimplemented(
-                "only support dense tensor type for now"));
-          }
-        }
-      }
-      // Copy op output type
-      std::vector<ir::Type> op_output_types;
-      ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
-      op_output_types.push_back(t1);
-
-      // Get op info
-      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
-      // Generate new op
-      ir::Operation* op = ir::Operation::Create(
-          vec_inputs, op_item->attributes(), op_output_types, op_info);
-      program->block()->push_back(op);
-      map_op_pair[op_item] = op;
-      // only deal with single output
-      if (op_item->num_results() > 0) {
-        for (size_t i = 0; i < op_item->num_results(); ++i) {
-          map_value_pair[op_item->result(i)] = op->result(i);
-        }
-      }
-      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
-      continue;
-    }
-
-    if (op_item->name() == "builtin.slice") {
-      std::vector<ir::OpResult> vec_inputs;
-      std::vector<ir::Type> op_output_types;
-      if (op_item->num_operands() > 0) {
-        for (size_t i = 0; i < op_item->num_operands(); ++i) {
-          auto cur_in = op_item->operand_source(i);
-          if (!cur_in) {
-            vec_inputs.emplace_back();
-            continue;
-          }
-          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
-                            true,
-                            phi::errors::PreconditionNotMet(
-                                "[%d]'s input of [%s] op MUST in map pair",
-                                i,
-                                op_item->name()));
-          auto new_in = map_value_pair.at(cur_in);
-          vec_inputs.push_back(new_in);
-
-          if (new_in.type().isa<ir::VectorType>()) {
-            auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
-            auto index = op_item->attributes()
-                             .at("index")
-                             .dyn_cast<ir::Int32Attribute>()
-                             .data();
-            op_output_types.push_back(vec_types[index]);
-          } else {
-            PADDLE_THROW(
-                phi::errors::Unimplemented("only support vector type for now"));
-          }
-        }
-      }
-
-      // Get op info
-      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
-      // Generate new op
-      ir::Operation* op = ir::Operation::Create(
-          vec_inputs, op_item->attributes(), op_output_types, op_info);
-      program->block()->push_back(op);
-      map_op_pair[op_item] = op;
-      // only deal with single output
-      if (op_item->num_results() > 0) {
-        for (size_t i = 0; i < op_item->num_results(); ++i) {
-          map_value_pair[op_item->result(i)] = op->result(i);
-        }
-      }
-      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
-      continue;
-    }
-
-    if (op_item->name() == "builtin.split") {
-      std::vector<phi::Place> out_places(op_item->num_results());
-      // Copy op inputs
-      std::vector<ir::OpResult> vec_inputs;
-      std::vector<ir::Type> op_output_types;
-      if (op_item->num_operands() > 0) {
-        for (size_t i = 0; i < op_item->num_operands(); ++i) {
-          auto cur_in = op_item->operand_source(i);
-          if (!cur_in) {
-            vec_inputs.emplace_back();
-            continue;
-          }
-          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
-                            true,
-                            phi::errors::PreconditionNotMet(
-                                "[%d]'s input of [%s] op MUST in map pair",
-                                i,
-                                op_item->name()));
-          auto new_in = map_value_pair.at(cur_in);
-          vec_inputs.push_back(new_in);
-
-          if (new_in.type().isa<ir::VectorType>()) {
-            auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
-            for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
-              op_output_types.push_back(vec_types[idx]);
-            }
-          } else {
-            PADDLE_THROW(
-                phi::errors::Unimplemented("only support vector type for now"));
-          }
-        }
-      }
-
-      // Get op info
-      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
-      // Generate new op
-      ir::Operation* op = ir::Operation::Create(
-          vec_inputs, op_item->attributes(), op_output_types, op_info);
-      program->block()->push_back(op);
-      map_op_pair[op_item] = op;
-      // only deal with single output
-      if (op_item->num_results() > 0) {
-        for (size_t i = 0; i < op_item->num_results(); ++i) {
-          map_value_pair[op_item->result(i)] = op->result(i);
-        }
-      }
-      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
+    if (SpecialOpList.count(op_item->name())) {
+      DealWithSpecialBuiltinOps(
+          op_item, program.get(), &map_op_pair, &map_value_pair, ctx);
       continue;
     }
@@ -1167,7 +1178,11 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
       }
     }
   }
-
+  if (VLOG_IS_ON(2)) {
+    std::stringstream ss1;
+    program->Print(ss1);
+    VLOG(2) << "Program after lowering to kernel pass : " << ss1.str();
+  }
   return program;
 }
 
--
GitLab
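For context, the net effect of this refactor is the dispatch pattern sketched
below: three inlined copies of the builtin-op handling become one SpecialOpList
guard plus a shared DealWithSpecialBuiltinOps handler. This is a minimal,
self-contained C++ analogue, not Paddle code; the Operation struct, the
kSpecialOps name, and the printed strings are illustrative stand-ins for the
real IR API.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for ir::Operation; only the name is needed for dispatch.
struct Operation {
  std::string name;
};

// Mirrors SpecialOpList in pd_op_to_kernel_pass.cc.
const std::unordered_set<std::string> kSpecialOps = {
    "builtin.combine", "builtin.slice", "builtin.split"};

// Mirrors DealWithSpecialBuiltinOps: one entry point that branches on the
// op name, so the main lowering loop needs only a single guard.
void DealWithSpecialBuiltinOps(const Operation& op) {
  if (op.name == "builtin.combine") {
    std::cout << "combine: rebuild a VectorType output from the input types\n";
  } else if (op.name == "builtin.slice") {
    std::cout << "slice: pick one element type via the 'index' attribute\n";
  } else if (op.name == "builtin.split") {
    std::cout << "split: expand vector element types into the outputs\n";
  }
}

int main() {
  const std::vector<Operation> block = {
      {"builtin.combine"}, {"pd.relu"}, {"builtin.split"}};
  for (const Operation& op : block) {
    if (kSpecialOps.count(op.name)) {  // the guard the patch adds to the loop
      DealWithSpecialBuiltinOps(op);
      continue;  // special builtin ops skip normal kernel selection
    }
    std::cout << op.name << ": normal kernel lowering path\n";
  }
  return 0;
}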