diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc
index 4b9dd25d67e00325c069703898a0271a029ee7f7..e07075a2c026cf9931dc9c0fe391bba553d3e05f 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc
@@ -52,8 +52,11 @@ void PaddleDialect::initialize() {
   RegisterOps<paddle::dialect::AddNOp,
               paddle::dialect::AddN_Op,
               paddle::dialect::AddNWithKernelOp,
+              paddle::dialect::FusedGemmEpilogueOp,
+              paddle::dialect::FusedGemmEpilogueGradOp,
               paddle::dialect::SplitGradOp,
               paddle::dialect::IfOp>();
+  RegisterInterfaces<ParameterConvertInterface>();
 }
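Note: a manually written op only becomes visible to the IR after it appears in the dialect's `RegisterOps<...>` list above. A minimal sketch of what that buys, assuming the `ir::IrContext` lookup API of this era of `paddle/ir` (the usage itself is illustrative, not part of the PR):

```cpp
// Illustrative only: an op name can be resolved to an OpInfo only after its
// dialect has registered the op in initialize().
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::OpInfo info =
    ctx->GetRegisteredOpInfo(paddle::dialect::FusedGemmEpilogueOp::name());
// info is valid (non-null) only because initialize() listed the op above.
```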
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
index 599e5c956450745e25e56a79b98e2525bccdf373..058a08a384d2db6b8e0ee380f5d2e9c8ec29bd85 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
@@ -24,6 +24,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/fusion.h"
 #include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
@@ -409,6 +410,442 @@ void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) {
   fn(infer_meta);
 }
 
+const char *FusedGemmEpilogueOp::attributes_name[3] = {
+    "trans_x", "trans_y", "activation"};
+
+OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() {
+  std::vector<paddle::dialect::OpInputInfo> inputs = {
+      paddle::dialect::OpInputInfo(
+          "x", "paddle::dialect::DenseTensorType", false, false, false, false),
+      paddle::dialect::OpInputInfo(
+          "y", "paddle::dialect::DenseTensorType", false, false, false, false),
+      paddle::dialect::OpInputInfo("bias",
+                                   "paddle::dialect::DenseTensorType",
+                                   false,
+                                   false,
+                                   false,
+                                   false)};
+  std::vector<paddle::dialect::OpAttributeInfo> attributes = {
+      paddle::dialect::OpAttributeInfo("trans_x", "ir::BoolAttribute", ""),
+      paddle::dialect::OpAttributeInfo("trans_y", "ir::BoolAttribute", ""),
+      paddle::dialect::OpAttributeInfo("activation", "ir::StrAttribute", "")};
+  std::vector<paddle::dialect::OpOutputInfo> outputs = {
+      paddle::dialect::OpOutputInfo(
+          "out", "paddle::dialect::DenseTensorType", false, false),
+      paddle::dialect::OpOutputInfo(
+          "reserve_space", "paddle::dialect::DenseTensorType", true, false)};
+  paddle::dialect::OpRunTimeInfo run_time_info(
+      "FusedGemmEpilogueInferMeta",
+      {"x", "y", "bias", "trans_x", "trans_y", "activation"},
+      {""},
+      {""},
+      {""},
+      {},
+      {},
+      {});
+
+  return std::make_tuple(
+      inputs, attributes, outputs, run_time_info, "fused_gemm_epilogue");
+}
+
+void FusedGemmEpilogueOp::Build(ir::Builder &builder,
+                                ir::OperationArgument &argument,
+                                ir::OpResult x_,
+                                ir::OpResult y_,
+                                ir::OpResult bias_,
+                                ir::AttributeMap attributes) {
+  bool trans_x = attributes.at("trans_x").dyn_cast<ir::BoolAttribute>().data();
+
+  bool trans_y = attributes.at("trans_y").dyn_cast<ir::BoolAttribute>().data();
+
+  std::string activation =
+      attributes.at("activation").dyn_cast<ir::StrAttribute>().AsString();
+
+  VLOG(4) << "Builder construction inputs";
+  std::vector<ir::OpResult> argument_inputs = {x_, y_, bias_};
+  argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
+
+  VLOG(4) << "Builder construction attributes";
+  ir::Attribute attr_trans_x =
+      ir::BoolAttribute::get(ir::IrContext::Instance(), trans_x);
+  argument.AddAttribute("trans_x", attr_trans_x);
+  ir::Attribute attr_trans_y =
+      ir::BoolAttribute::get(ir::IrContext::Instance(), trans_y);
+  argument.AddAttribute("trans_y", attr_trans_y);
+  ir::Attribute attr_activation =
+      ir::StrAttribute::get(ir::IrContext::Instance(), activation);
+  argument.AddAttribute("activation", attr_activation);
+
+  VLOG(4) << "Builder construction outputs";
+  paddle::dialect::DenseTensorType x =
+      x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+  (void)x;
+  paddle::dialect::DenseTensorType y =
+      y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+  (void)y;
+  paddle::dialect::DenseTensorType bias =
+      bias_.type().dyn_cast<paddle::dialect::DenseTensorType>();
+  (void)bias;
+
+  VLOG(4) << "Builder construction dense_x";
+  phi::DenseTensor dense_x(
+      std::make_unique<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CPUPlace())
+          .get(),
+      phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(x.dtype()),
+                           x.dims(),
+                           x.data_layout(),
+                           x.lod(),
+                           x.offset()));
+  VLOG(4) << "Builder construction meta_x";
+  phi::MetaTensor meta_x(&dense_x);
+
+  VLOG(4) << "Builder construction dense_y";
+  phi::DenseTensor dense_y(
+      std::make_unique<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CPUPlace())
+          .get(),
+      phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(y.dtype()),
+                           y.dims(),
+                           y.data_layout(),
+                           y.lod(),
+                           y.offset()));
+  VLOG(4) << "Builder construction meta_y";
+  phi::MetaTensor meta_y(&dense_y);
+
+  VLOG(4) << "Builder construction dense_bias";
+  phi::DenseTensor dense_bias(
+      std::make_unique<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CPUPlace())
+          .get(),
+      phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(bias.dtype()),
+                           bias.dims(),
+                           bias.data_layout(),
+                           bias.lod(),
+                           bias.offset()));
+  VLOG(4) << "Builder construction meta_bias";
+  phi::MetaTensor meta_bias(&dense_bias);
+  phi::DenseTensor dense_out;
+  phi::MetaTensor meta_out(&dense_out);
+  phi::DenseTensor dense_reserve_space;
+  phi::MetaTensor meta_reserve_space(&dense_reserve_space);
+
+  phi::FusedGemmEpilogueInferMeta(
+      meta_x,
+      meta_y,
+      meta_bias,
+      trans_x,
+      trans_y,
+      activation,
+      &meta_out,
+      activation == "none" ? nullptr : &meta_reserve_space);
+
+  std::vector<ir::Type> argument_outputs;
+  ir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
+      ir::IrContext::Instance(),
+      paddle::dialect::TransToIrDataType(dense_out.dtype()),
+      dense_out.dims(),
+      dense_out.layout(),
+      dense_out.lod(),
+      dense_out.offset());
+  argument_outputs.push_back(out_dense_tensor_type);
+
+  ir::Type reserve_space_dense_tensor_type =
+      activation == "none"
+          ? ir::Type()
+          : paddle::dialect::DenseTensorType::get(
+                ir::IrContext::Instance(),
+                paddle::dialect::TransToIrDataType(dense_reserve_space.dtype()),
+                dense_reserve_space.dims(),
+                dense_reserve_space.layout(),
+                dense_reserve_space.lod(),
+                dense_reserve_space.offset());
+  argument_outputs.push_back(reserve_space_dense_tensor_type);
+
+  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+}
+
+void FusedGemmEpilogueOp::Verify() {
+  VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
+             "FusedGemmEpilogueOp.";
+  VLOG(4) << "Verifying inputs:";
+  {
+    auto input_size = num_operands();
+    PADDLE_ENFORCE_EQ(
+        input_size,
+        3u,
+        phi::errors::PreconditionNotMet(
+            "The size %d of inputs must be equal to 3.", input_size));
+    PADDLE_ENFORCE((*this)
+                       ->operand_source(0)
+                       .type()
+                       .isa<paddle::dialect::DenseTensorType>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type validation failed for the 0th input."));
+    PADDLE_ENFORCE((*this)
+                       ->operand_source(1)
+                       .type()
+                       .isa<paddle::dialect::DenseTensorType>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type validation failed for the 1th input."));
+    PADDLE_ENFORCE((*this)
+                       ->operand_source(2)
+                       .type()
+                       .isa<paddle::dialect::DenseTensorType>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type validation failed for the 2th input."));
+  }
+  VLOG(4) << "Verifying attributes:";
+  {
+    auto &attributes = this->attributes();
+    PADDLE_ENFORCE(attributes.count("trans_x") > 0 &&
+                       attributes.at("trans_x").isa<ir::BoolAttribute>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type of attribute: trans_x is not right."));
+    PADDLE_ENFORCE(attributes.count("trans_y") > 0 &&
+                       attributes.at("trans_y").isa<ir::BoolAttribute>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type of attribute: trans_y is not right."));
+    PADDLE_ENFORCE(attributes.count("activation") > 0 &&
+                       attributes.at("activation").isa<ir::StrAttribute>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type of attribute: activation is not right."));
+  }
+  VLOG(4) << "Verifying outputs:";
+  {
+    auto output_size = num_results();
+    PADDLE_ENFORCE_EQ(
+        output_size,
+        2u,
+        phi::errors::PreconditionNotMet(
+            "The size %d of outputs must be equal to 2.", output_size));
+    PADDLE_ENFORCE(
+        (*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
+        phi::errors::PreconditionNotMet(
+            "Type validation failed for the 0th output."));
+    if (auto output_1_type = (*this)->result(1).type()) {
+      PADDLE_ENFORCE(output_1_type.isa<paddle::dialect::DenseTensorType>(),
+                     phi::errors::PreconditionNotMet(
+                         "Type validation failed for the 1th output."));
+    }
+  }
+  VLOG(4) << "End Verifying for: FusedGemmEpilogueOp.";
+}
+
+void FusedGemmEpilogueOp::InferMeta(phi::InferMetaContext *infer_meta) {
+  auto fn = PD_INFER_META(phi::FusedGemmEpilogueInferMeta);
+  fn(infer_meta);
+}
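The `Build` above derives its result types by running the phi InferMeta function over placeholder CPU `DenseTensor`s that carry only metadata. A hedged sketch of that idiom in isolation (shapes invented; allocator type as used in this diff):

```cpp
// "Meta-only tensor" idiom: the allocator is never asked for real memory.
// The DenseTensor exists only so a MetaTensor can record dims/dtype, which
// Build then reads back to mint the op's result ir::Types.
auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
    paddle::platform::CPUPlace());
phi::DenseTensor dense_out(alloc.get(), phi::DenseTensorMeta());
phi::MetaTensor meta_out(&dense_out);
meta_out.set_dims(phi::make_ddim({8, 32}));   // what an InferMeta would write
meta_out.set_dtype(phi::DataType::FLOAT32);
// dense_out.dims() / dense_out.dtype() now feed DenseTensorType::get(...).
```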
"paddle::dialect::DenseTensorType", false, false), + paddle::dialect::OpOutputInfo( + "y_grad", "paddle::dialect::DenseTensorType", false, false), + paddle::dialect::OpOutputInfo( + "bias_grad", "paddle::dialect::DenseTensorType", false, false)}; + paddle::dialect::OpRunTimeInfo run_time_info("FusedGemmEpilogueGradInferMeta", + {"x", + "y", + "reserve_space", + "out_grad", + "trans_x", + "trans_y", + "activation_grad"}, + {""}, + {""}, + {""}, + {}, + {}, + {}); + + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "fused_gemm_epilogue_grad"); +} + +void FusedGemmEpilogueGradOp::Build(ir::Builder &builder, + ir::OperationArgument &argument, + ir::OpResult x_, + ir::OpResult y_, + ir::OpResult reserve_space_, + ir::OpResult out_grad_, + ir::AttributeMap attributes) { + bool trans_x = attributes.at("trans_x").dyn_cast().data(); + + bool trans_y = attributes.at("trans_y").dyn_cast().data(); + + std::string activation_grad = + attributes.at("activation_grad").dyn_cast().AsString(); + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = { + x_, y_, reserve_space_, out_grad_}; + argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + + VLOG(4) << "Builder construction attributes"; + ir::Attribute attr_trans_x = + ir::BoolAttribute::get(ir::IrContext::Instance(), trans_x); + argument.AddAttribute("trans_x", attr_trans_x); + ir::Attribute attr_trans_y = + ir::BoolAttribute::get(ir::IrContext::Instance(), trans_y); + argument.AddAttribute("trans_y", attr_trans_y); + ir::Attribute attr_activation_grad = + ir::StrAttribute::get(ir::IrContext::Instance(), activation_grad); + argument.AddAttribute("activation_grad", attr_activation_grad); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + paddle::dialect::DenseTensorType y = + y_.type().dyn_cast(); + (void)y; + paddle::dialect::DenseTensorType reserve_space = + reserve_space_ + ? reserve_space_.type().dyn_cast() + : paddle::dialect::DenseTensorType(); + (void)reserve_space; + paddle::dialect::DenseTensorType out_grad = + out_grad_.type().dyn_cast(); + (void)out_grad; + + VLOG(4) << "Builder construction dense_x"; + phi::DenseTensor dense_x( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset())); + VLOG(4) << "Builder construction meta_x"; + phi::MetaTensor meta_x(&dense_x); + + VLOG(4) << "Builder construction dense_y"; + phi::DenseTensor dense_y( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(y.dtype()), + y.dims(), + y.data_layout(), + y.lod(), + y.offset())); + VLOG(4) << "Builder construction meta_y"; + phi::MetaTensor meta_y(&dense_y); + + VLOG(4) << "Builder construction dense_reserve_space"; + std::unique_ptr dense_reserve_space = + reserve_space_ + ? 
std::make_unique( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + phi::DenseTensorMeta( + paddle::dialect::TransToPhiDataType(reserve_space.dtype()), + reserve_space.dims(), + reserve_space.data_layout(), + reserve_space.lod(), + reserve_space.offset())) + : nullptr; + VLOG(4) << "Builder construction meta_reserve_space"; + phi::MetaTensor meta_reserve_space(dense_reserve_space.get()); + + VLOG(4) << "Builder construction dense_out_grad"; + phi::DenseTensor dense_out_grad( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + phi::DenseTensorMeta( + paddle::dialect::TransToPhiDataType(out_grad.dtype()), + out_grad.dims(), + out_grad.data_layout(), + out_grad.lod(), + out_grad.offset())); + VLOG(4) << "Builder construction meta_out_grad"; + phi::MetaTensor meta_out_grad(&dense_out_grad); + phi::DenseTensor dense_x_grad; + phi::MetaTensor meta_x_grad(&dense_x_grad); + phi::DenseTensor dense_y_grad; + phi::MetaTensor meta_y_grad(&dense_y_grad); + phi::DenseTensor dense_bias_grad; + phi::MetaTensor meta_bias_grad(&dense_bias_grad); + + phi::FusedGemmEpilogueGradInferMeta(meta_x, + meta_y, + meta_reserve_space, + meta_out_grad, + trans_x, + trans_y, + activation_grad, + &meta_x_grad, + &meta_y_grad, + &meta_bias_grad); + + std::vector argument_outputs; + ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + ir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_x_grad.dtype()), + dense_x_grad.dims(), + dense_x_grad.layout(), + dense_x_grad.lod(), + dense_x_grad.offset()); + argument_outputs.push_back(x_grad_dense_tensor_type); + + ir::Type y_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + ir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_y_grad.dtype()), + dense_y_grad.dims(), + dense_y_grad.layout(), + dense_y_grad.lod(), + dense_y_grad.offset()); + argument_outputs.push_back(y_grad_dense_tensor_type); + + ir::Type bias_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + ir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_bias_grad.dtype()), + dense_bias_grad.dims(), + dense_bias_grad.layout(), + dense_bias_grad.lod(), + dense_bias_grad.offset()); + argument_outputs.push_back(bias_grad_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); +} + +void FusedGemmEpilogueGradOp::Verify() {} + +void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta); + fn(infer_meta); +} + const char *SplitGradOp::attributes_name[1] = {"axis"}; OpInfoTuple SplitGradOp::GetOpInfo() { @@ -673,4 +1110,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h index 6e120317cb461f2a4a76bf7bb99421cd02757d6f..c8a5e1658ec4d42236b30bbb099ef0bf09252521 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h @@ -94,6 +94,62 @@ class AddNWithKernelOp : public ir::Op { + public: + using Op::Op; + static const char *name() { 
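Taken together, the two `Build` methods let a lowering pass wire the backward op directly off the forward op's results. A hedged usage sketch (not part of the PR; `ctx`, `builder`, `x`, `y`, `bias`, and `out_grad` are assumed to be in scope, and the attribute values are made up):

```cpp
ir::IrContext *ctx = ir::IrContext::Instance();
ir::AttributeMap attrs = {
    {"trans_x", ir::BoolAttribute::get(ctx, false)},
    {"trans_y", ir::BoolAttribute::get(ctx, false)},
    {"activation", ir::StrAttribute::get(ctx, "relu")}};
auto fwd =
    builder.Build<paddle::dialect::FusedGemmEpilogueOp>(x, y, bias, attrs);

ir::AttributeMap grad_attrs = {
    {"trans_x", ir::BoolAttribute::get(ctx, false)},
    {"trans_y", ir::BoolAttribute::get(ctx, false)},
    {"activation_grad", ir::StrAttribute::get(ctx, "relu_grad")}};
// fwd.reserve_space() is null-typed when activation == "none"; the grad
// Build tolerates that and feeds a nullptr MetaTensor into its InferMeta.
auto bwd = builder.Build<paddle::dialect::FusedGemmEpilogueGradOp>(
    x, y, fwd.reserve_space(), out_grad, grad_attrs);
```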
return "pd.fused_gemm_epilogue"; } + static const char *attributes_name[3]; + static constexpr uint32_t attributes_num = 3; + static OpInfoTuple GetOpInfo(); + + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult x_, + ir::OpResult y_, + ir::OpResult bias_, + ir::AttributeMap attributes); + void Verify(); + ir::Value x() { return operand_source(0); } + ir::Value y() { return operand_source(1); } + ir::Value bias() { return operand_source(2); } + ir::OpResult out() { return result(0); } + ir::OpResult reserve_space() { return result(1); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class FusedGemmEpilogueGradOp + : public ir::Op { + public: + using Op::Op; + static const char *name() { return "pd.fused_gemm_epilogue_grad"; } + static const char *attributes_name[3]; + static constexpr uint32_t attributes_num = 3; + static OpInfoTuple GetOpInfo(); + + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult x_, + ir::OpResult y_, + ir::OpResult reserve_space_, + ir::OpResult out_grad_, + ir::AttributeMap attributes); + void Verify(); + ir::Value x() { return operand_source(0); } + ir::Value y() { return operand_source(1); } + ir::Value reserve_space() { return operand_source(2); } + ir::Value out_grad() { return operand_source(3); } + ir::OpResult x_grad() { return result(0); } + ir::OpResult y_grad() { return result(1); } + ir::OpResult bias_grad() { return result(2); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + class SplitGradOp : public ir::Op { public: using Op::Op; @@ -141,5 +197,8 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) + IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) #endif diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index 2cf00037eb5fcfd55688f75873a3466222f406bf..7e612d6318d36268c688a91b2eddc0e30d92710d 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -70,6 +70,7 @@ class IR_API Block { bool HasOneUse() const; BlockOperand *first_use_addr() { return &first_use_; } + // This is a unsafe funcion, please use it carefully. void ResetOpListOrder(const OpListType &new_op_list); private: diff --git a/paddle/ir/transforms/reorder_block_ops_pass.cc b/paddle/ir/transforms/reorder_block_ops_pass.cc index d922326677985a609b47a367ede5af60a80ea523..91b4b52229f1072d2e1bd6467d53819cabb67d75 100644 --- a/paddle/ir/transforms/reorder_block_ops_pass.cc +++ b/paddle/ir/transforms/reorder_block_ops_pass.cc @@ -22,9 +22,6 @@ namespace { -// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be -// removed by dce pass. -// Now just a naive implementation. 
diff --git a/paddle/ir/transforms/reorder_block_ops_pass.cc b/paddle/ir/transforms/reorder_block_ops_pass.cc
index d922326677985a609b47a367ede5af60a80ea523..91b4b52229f1072d2e1bd6467d53819cabb67d75 100644
--- a/paddle/ir/transforms/reorder_block_ops_pass.cc
+++ b/paddle/ir/transforms/reorder_block_ops_pass.cc
@@ -22,9 +22,6 @@
 
 namespace {
 
-// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
-// removed by dce pass.
-// Now just a naive implementation.
 class ReorderBlockOpsPass : public ir::Pass {
  public:
   ReorderBlockOpsPass() : ir::Pass("ReorderBlockOpsPass", 0) {}
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 49ec2405051fee3add1a4ab479b141803c0dacfb..2e619a3566ff3604524273d33a55b8b01ad3ac5e 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -446,6 +446,190 @@ void MultiEncoderXPUInferMeta(
   }
 }
 
+void FusedGemmEpilogueInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                const MetaTensor& bias,
+                                bool trans_x,
+                                bool trans_y,
+                                const std::string& activation,
+                                MetaTensor* out,
+                                MetaTensor* reserve_space) {
+  const auto& x_dims = x.dims();
+  const auto& y_dims = y.dims();
+  const auto& bias_dims = bias.dims();
+
+  PADDLE_ENFORCE_EQ(y_dims.size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The Input tensor Y's dimension of FusedGemmEpilogueOp"
+                        " should be 2, but got %d.",
+                        y_dims.size()));
+
+  PADDLE_ENFORCE_GE(x_dims.size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The Input tensor X's dimension of FusedGemmEpilogueOp"
+                        " should be >= 2, but got %d.",
+                        x_dims.size()));
+
+  PADDLE_ENFORCE_EQ(
+      bias_dims.size(),
+      1,
+      phi::errors::InvalidArgument(
+          "The Input tensor bias's dimension of FusedGemmEpilogueOp"
+          " should be == 1, but got %d.",
+          bias_dims.size()));
+
+  PADDLE_ENFORCE_EQ(bias_dims[0],
+                    trans_y ? y_dims[0] : y_dims[1],
+                    phi::errors::InvalidArgument(
+                        "The Input tensor bias's dimension 0"
+                        " should be == Y[-1], but got bias's shape = [%s] "
+                        "and Y's shape = [%s]",
+                        bias_dims,
+                        y_dims));
+
+  auto x_mat_dims =
+      phi::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1);
+
+  int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1];
+  int K_from_y = trans_y ? y_dims[1] : y_dims[0];
+
+  PADDLE_ENFORCE_EQ(
+      K_from_x,
+      K_from_y,
+      phi::errors::InvalidArgument(
+          "The last dimension of X should be equal to Y's first dimension. "
+          "But received X[-1] = [%d], Y[0] = [%d].",
+          K_from_x,
+          K_from_y));
+
+  std::vector<int64_t> out_dims;
+  out_dims.reserve(static_cast<size_t>(x_dims.size()));
+  if (trans_x) {
+    for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]);
+  } else {
+    for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]);
+  }
+
+  if (trans_y) {
+    out_dims.push_back(y_dims[0]);
+  } else {
+    out_dims.push_back(y_dims[1]);
+  }
+  out->set_dims(phi::make_ddim(out_dims));
+  out->set_dtype(x.dtype());
+
+  if (reserve_space) {
+    reserve_space->set_dims(phi::make_ddim(out_dims));
+    reserve_space->set_dtype(x.dtype());
+    if (activation == "none") {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "The ReserveSpace would not be used when activation = \"none\""));
+    } else {
+      int min_size_of_n = activation == "relu" ? 128 : 8;
+      int N_size = trans_y ? y_dims[0] : y_dims[1];
+      PADDLE_ENFORCE_EQ(N_size % min_size_of_n,
+                        0,
+                        phi::errors::InvalidArgument(
+                            "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
+                            "should be a multiple of %d when the auxiliary key "
+                            "is given and activation=%s, but got N = %d.",
+                            min_size_of_n,
+                            activation,
+                            N_size));
+    }
+  }
+}
+
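For review convenience, the shape rule encoded above, restated as a standalone snippet (plain C++, not Paddle code; it mirrors `phi::flatten_to_2d` by hand):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Shape rule for out = activation(matmul(X, Y) + bias): flatten X to 2-D
// around the contracted axis, require K to agree between X and Y, then keep
// X's batch dims and append N from Y.
std::vector<int64_t> FusedGemmOutShape(const std::vector<int64_t> &x,
                                       const std::vector<int64_t> &y,
                                       bool trans_x,
                                       bool trans_y) {
  size_t split = trans_x ? 1 : x.size() - 1;
  int64_t head = 1, tail = 1;
  for (size_t i = 0; i < split; ++i) head *= x[i];
  for (size_t i = split; i < x.size(); ++i) tail *= x[i];
  int64_t k_from_x = trans_x ? head : tail;
  int64_t k_from_y = trans_y ? y[1] : y[0];
  assert(k_from_x == k_from_y && "contracted dimensions must agree");

  std::vector<int64_t> out(trans_x ? x.begin() + 1 : x.begin(),
                           trans_x ? x.end() : x.end() - 1);
  out.push_back(trans_y ? y[0] : y[1]);
  return out;
}

// FusedGemmOutShape({8, 16, 64}, {64, 32}, false, false) == {8, 16, 32}
```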
+void FusedGemmEpilogueGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& reserve_space,
+                                    const MetaTensor& out_grad,
+                                    bool trans_x,
+                                    bool trans_y,
+                                    const std::string& activation_grad,
+                                    MetaTensor* x_grad,
+                                    MetaTensor* y_grad,
+                                    MetaTensor* bias_grad) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  auto dout_dims = out_grad.dims();
+
+  PADDLE_ENFORCE_GE(
+      dout_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "The Input tensor DOut's dimension of FusedGemmEpilogueGradOp"
+          " should be >= 2, but got %d.",
+          dout_dims.size()));
+
+  PADDLE_ENFORCE_EQ(
+      y_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "The Input tensor Y's dimension of FusedGemmEpilogueGradOp"
+          " should be 2, but got %d.",
+          y_dims.size()));
+
+  PADDLE_ENFORCE_GE(
+      x_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "The Input tensor X's dimension of FusedGemmEpilogueGradOp"
+          " should be >= 2, but got %d.",
+          x_dims.size()));
+
+  PADDLE_ENFORCE_EQ(
+      dout_dims.size(),
+      x_dims.size(),
+      phi::errors::InvalidArgument(
+          "The Input tensor DOut's and X's dimension of "
+          "FusedGemmEpilogueGradOp"
+          " should be the same, but got DOut's dim = %d and X's = %d.",
+          dout_dims.size(),
+          x_dims.size()));
+
+  auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1);
+  auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
+
+  PADDLE_ENFORCE_EQ(
+      dout_mat_dims[1],
+      trans_y ? y_dims[0] : y_dims[1],
+      phi::errors::InvalidArgument(
+          "The last dimension of DOut should be equal to Y's last "
+          "dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
+          dout_mat_dims[1],
+          y_dims[1]));
+
+  PADDLE_ENFORCE_EQ(
+      dout_mat_dims[0],
+      trans_x ? x_mat_dims[1] : x_mat_dims[0],
+      phi::errors::InvalidArgument(
+          "The first dimension of DOut should be equal to X's first "
+          "dimension. But received DOut[0] = [%d], X[0] = [%d].",
+          dout_mat_dims[0],
+          x_mat_dims[0]));
+
+  if (activation_grad != "none" && !reserve_space) {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "The ReserveSpace should not be empty when "
+        "activation_grad == {relu_grad, gelu_grad}."));
+  }
+
+  if (x_grad) {
+    x_grad->set_dims(x_dims);
+    x_grad->set_dtype(x.dtype());
+  }
+  y_grad->set_dims(y_dims);
+  y_grad->set_dtype(y.dtype());
+
+  if (bias_grad) {
+    int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
+    bias_grad->set_dims(phi::make_ddim({dbias_dim}));
+    bias_grad->set_dtype(y.dtype());
+  }
+}
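A worked instance of the grad-shape contract above (standalone, not Paddle code): the input gradients simply inherit the forward shapes, and `bias_grad` is one-dimensional with size N.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // X = [8, 16, 64], Y = [64, 32], no transposes.
  std::vector<int64_t> x{8, 16, 64}, y{64, 32}, dout{8, 16, 32};
  bool trans_y = false;
  assert(dout.back() == (trans_y ? y[0] : y[1]));  // last dim of DOut is N
  std::vector<int64_t> x_grad = x;                 // same shape as X
  std::vector<int64_t> y_grad = y;                 // same shape as Y
  int64_t dbias_dim = trans_y ? y[0] : y[1];       // bias_grad = [32]
  assert(x_grad.size() == dout.size() && y_grad[1] == dbias_dim);
  return 0;
}
```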
+
 void FusedMultiTransformerXpuInferMeta(
     const MetaTensor& x,
     const std::vector<const MetaTensor*>& ln_scale,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index dd5fcfcbf85895bd9eaeb297ab85d1c7b8969fa9..208f23c681febcd1d90e237a90f7f8d4c5d93a3e 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -123,6 +123,26 @@ void MultiEncoderXPUInferMeta(
     MetaTensor* x_fp16,
     MetaTensor* out_fp16);
 
+void FusedGemmEpilogueInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                const MetaTensor& bias,
+                                bool trans_x,
+                                bool trans_y,
+                                const std::string& activation,
+                                MetaTensor* out,
+                                MetaTensor* reserve_space);
+
+void FusedGemmEpilogueGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& reserve_space,
+                                    const MetaTensor& out_grad,
+                                    bool trans_x,
+                                    bool trans_y,
+                                    const std::string& activation_grad,
+                                    MetaTensor* x_grad,
+                                    MetaTensor* y_grad,
+                                    MetaTensor* bias_grad);
+
 void FusedMultiTransformerXpuInferMeta(
     const MetaTensor& x,
     const std::vector<const MetaTensor*>& ln_scale,