From 14393611a7281862fd5b7930b3e820d787ee577b Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Fri, 18 Aug 2023 21:40:15 +0800 Subject: [PATCH] [NewIR]Add builtin.split op (#56431) * [prim][newir] add basic framework for primitive * support desctensor in new ir * add vjp interface * support vjp in new ir * support vjp in new ir * polish vjp interface * fix stop_gradients set * fix vjp dispatch * add comment * add vjp test for new ir * add test for tanh vjp * [prim][newir] add basic framework for primitive * support desctensor in new ir * support vjp in new ir * support vjp in new ir * polish vjp interface * fix stop_gradients set * fix vjp dispatch * add comment * add vjp test for new ir * add test for tanh vjp * add eager and static backend for warp lower level api * support call_vjp pybind * polish code and add test for vjp * remove useless code * polish code * remove useless code * support mean vjp * backward origin code * add test for mean vjp and support has_vjp function * fix call_vjp * polish code * add attrs and dtype interface * add primitive ops set for backend * fix compile bugs * fix some bugs * fix windows bugs * add vjp test for tanh_ * fix inference CI * fix inference ci * modify fluid cmake * origin test of tanh and mean passed * fix conflict * modify stop_gradient * remove useless deps * add cmake * modify block.ops * modify test * fix conflict * reply review comments * reply review comments * pulish code * fix comment * fix test * polish code * modify backward stop_gradients * modify static_backend.cc * refactor grad_op * support add and add_inplace vjp * remove useless code * remove useless code * remove cout * modify add_n * modify add_n with add_vjp test * modify add_n with add_vjp test * fix conflict and concat call_vjp * modify backward test * Add more gen api * add builtin split op --------- Co-authored-by: cxxly Co-authored-by: Charles-hit Co-authored-by: zhangbo9674 Co-authored-by: YuanRisheng Co-authored-by: 0x45f --- .../ir/phi_kernel_adaptor/phi_kernel_util.cc | 34 ++++++++- .../ir/phi_kernel_adaptor/phi_kernel_util.h | 2 +- .../ir/transforms/pd_op_to_kernel_pass.cc | 72 ++++++++++++++++++- paddle/ir/core/builtin_dialect.cc | 1 + paddle/ir/core/builtin_op.cc | 41 +++++++++++ paddle/ir/core/builtin_op.h | 43 ++++++++++- 6 files changed, 187 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index 4b225201b9a..cadd9a29519 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -54,7 +54,10 @@ void AddNewData(ir::Value value, std::string>* variable_2_var_name, std::map* var_name_2_id, std::vector* variable_list) { - value_2_var_name->emplace(value, name); + if (value_2_var_name->count(value) == 0) { + value_2_var_name->emplace(value, name); + } + variable_2_var_name->emplace(var, name); if (var_name_2_id->count(name) == 0) { auto id = var_name_2_id->size(); @@ -174,7 +177,6 @@ void BuildValue(ir::Value value, var_name_2_id, variable_list); } - // Only support DenseTensor or Vector if (!value.type()) { var->GetMutable(); @@ -200,6 +202,7 @@ void BuildValue(ir::Value value, variable_2_var_name, var_name_2_id, variable_list); + var_i->GetMutable(); tensor_array->emplace_back(var_i); } @@ -412,6 +415,30 @@ void HandleForSpecialOp( std::string var_name = variable_2_var_name->at(variable_array[index]); value_2_var_name->emplace(out_value, var_name); } + + if (op_name == "builtin.split") { + VLOG(6) << "Handle for builtin.split"; + auto in_value = op->operand_source(0); + PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value), + true, + phi::errors::PreconditionNotMet( + "input of buildin split not in name map")); + + auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value)); + auto variable_array = in_var->Get(); + + for (uint64_t idx = 0; idx < variable_array.size(); ++idx) { + auto out_value = op->result(idx); + PADDLE_ENFORCE_EQ( + variable_2_var_name->count(variable_array[idx]), + true, + phi::errors::PreconditionNotMet("[%d] the variable in build split " + "input MUST in variable name map", + idx)); + std::string var_name = variable_2_var_name->at(variable_array[idx]); + value_2_var_name->emplace(out_value, var_name); + } + } } void HandleForInplaceOp( @@ -498,7 +525,8 @@ void BuildScope(const ir::Block& block, if (op_name == "pd.feed" || op_name == "pd.fetch" || op_name == "builtin.combine" || op_name == "builtin.set_parameter" || op_name == "builtin.get_parameter" || op_name == "builtin.slice" || - op_name == "pd.data" || op_name == "pd.shadow_output") { + op_name == "builtin.split" || op_name == "pd.data" || + op_name == "pd.shadow_output") { HandleForSpecialOp(op, inner_scope, var_name_prefix, diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index f59b8d927cb..2b024f47868 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -311,7 +311,7 @@ void BuildPhiContext(ir::Operation* op, ->Get())))); } else if (out_type.isa()) { OutListType outputs; - auto& variable_array = scope->FindVar(name_map.at(out_ptr)) + auto& variable_array = inner_scope->FindVar(name_map.at(out_ptr)) ->Get(); for (size_t i = 0; i < variable_array.size(); ++i) { outputs.emplace_back(OutType(const_cast( diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc index 027bd915236..762d2cf4b60 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc @@ -54,6 +54,7 @@ const std::unordered_set UnchangeOutputOps = { "pd.data", "builtin.combine", "builtin.slice", + "builtin.split", "pd.feed", "pd.fetch", "builtin.set_parameter", @@ -523,7 +524,76 @@ std::unique_ptr PdOpLowerToKernelPass(ir::Program* prog, op_output_types.push_back(allocated_dense_tensor_dtype); } else { PADDLE_THROW(phi::errors::Unimplemented( - "builtin.combine Result type only support DenseTensorType")); + "builtin.slice Result type only support DenseTensorType")); + } + } + } + // Get op info + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); + // Generate new op + ir::Operation* op = ir::Operation::Create( + vec_inputs, op_item->attributes(), op_output_types, op_info); + program->block()->push_back(op); + map_op_pair[op_item] = op; + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + map_value_pair[op_item->result(i)] = op->result(i); + } + } + VLOG(6) << "Deep copy a new builtin op: " << op_item->name(); + continue; + } + + if (op_item->name() == "builtin.split") { + phi::Place out_place = place; + // Copy op inputs + std::vector vec_inputs; + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", + i, + op_item->name())); + auto new_in = map_value_pair.at(cur_in); + vec_inputs.push_back(new_in); + + if (new_in.type().isa()) { + auto vec_types = new_in.type().dyn_cast().data(); + out_place = + vec_types[0] + .dyn_cast() + .place(); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("only support vector type for now")); + } + } + } + // Copy op output type + std::vector op_output_types; + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + auto result_type = op_item->result(i).type(); + if (!result_type) { + op_output_types.push_back(result_type); + } else if (result_type.isa()) { + auto allocated_dense_tensor_dtype = + paddle::dialect::AllocatedDenseTensorType::get( + ctx, + out_place, + result_type.dyn_cast()); + op_output_types.push_back(allocated_dense_tensor_dtype); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "builtin.split Result type only support DenseTensorType")); } } } diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index 3284a96c8b5..375bf90d2b8 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -55,6 +55,7 @@ void BuiltinDialect::initialize() { SetParameterOp, CombineOp, SliceOp, + SplitOp, ConstantOp>(); } diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index 8aff2f1f190..cee89a04f61 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -207,6 +207,46 @@ void SliceOp::Verify() const { output_type); } +void SplitOp::Verify() const { + // inputs.size() == 1 + IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1."); + + // input_type == Vector + auto input_type = (*this)->operand(0).type().dyn_cast(); + IR_ENFORCE(input_type, "The type of inputs[0] must be equal to VectorType."); + + // inputs[0].size() == outputs.size() + auto output_num = num_results(); + IR_ENFORCE(input_type.size() == output_num, + "The size %d of output must be equal to size %d of inputs.", + output_num, + input_type.size()); + + // for all i in outputs.size(): outputs[i].type == inputs[0][i].type + for (size_t i = 0; i < output_num; ++i) { + auto type = (*this)->result(i).type(); + IR_ENFORCE(input_type[i] == type, + "The type %s of inputs[0][%d] must be " + "equal to type %s of outputs[%d].", + input_type[i], + i, + type, + i); + } +} + +void SplitOp::Build(Builder &builder, + OperationArgument &argument, + const ir::OpResult &input) { + argument.inputs = {input}; + std::vector outputs_types; + for (size_t idx = 0; idx < input.type().dyn_cast().size(); + ++idx) { + argument.output_types.emplace_back( + input.type().dyn_cast()[idx]); + } +} + const char *ConstantOp::attributes_name[attributes_num] = {"value"}; // NOLINT void ConstantOp::Build(Builder &builder, @@ -232,5 +272,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::SplitOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp) diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index fe5b7116a29..ae1b748bd33 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -93,6 +93,13 @@ class IR_API CombineOp : public ir::Op { const std::vector &inputs); void Verify() const; + std::vector inputs() { + std::vector inputs; + for (uint32_t idx = 0; idx < num_operands(); idx++) { + inputs.push_back(operand_source(static_cast(idx))); + } + return inputs; + } ir::OpResult out() { return result(0); } }; @@ -108,8 +115,41 @@ class IR_API SliceOp : public ir::Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const ir::OpResult &input); + void Verify() const; - ir::OpResult out() { return result(0); } + ir::Value input() { return operand_source(0); } +}; + +/// +/// \brief SplitOp: SplitOp(OpOperand) +/// +class IR_API SplitOp : public ir::Op { + public: + using Op::Op; + + static const char *name() { return "builtin.split"; } + + static constexpr uint32_t attributes_num = 0; + + static constexpr const char **attributes_name = nullptr; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const ir::OpResult &input); + + void Verify() const; + ir::Value input() { return operand_source(0); } + std::vector outputs() { + std::vector outputs; + for (uint32_t idx = 0; idx < num_results(); idx++) { + outputs.push_back(result(static_cast(idx))); + } + return outputs; + } }; class IR_API ConstantLikeTrait : public OpTraitBase { @@ -146,5 +186,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::GetParameterOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SetParameterOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::CombineOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SliceOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SplitOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantOp) -- GitLab