Unverified · Commit 720018bb · Authored by: zyfncg · Committed by: GitHub

[IR] Add FusedGemmEpilogueOp in new IR (#57039)

* add FusedGemmEpilogueOp in new ir

* fix conflict
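
For orientation, a minimal usage sketch of the new forward op follows. The `FusedGemmEpilogueOp::Build` signature, attribute names, and result accessors are taken from this patch; the surrounding program/builder setup, the `FullOp` build overload, and the concrete shapes are illustrative assumptions, not part of this change.

```cpp
// Sketch only: constructing pd.fused_gemm_epilogue in the new IR.
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();

ir::Program program(ctx);
ir::Builder builder(ctx, program.block());

// Assumed inputs: x [64, 128], y [128, 256], bias [256], fp32 on CPU.
auto x = builder.Build<paddle::dialect::FullOp>(
    std::vector<int64_t>{64, 128}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
auto y = builder.Build<paddle::dialect::FullOp>(
    std::vector<int64_t>{128, 256}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
auto bias = builder.Build<paddle::dialect::FullOp>(
    std::vector<int64_t>{256}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

ir::AttributeMap attrs = {
    {"trans_x", ir::BoolAttribute::get(ctx, false)},
    {"trans_y", ir::BoolAttribute::get(ctx, false)},
    {"activation", ir::StrAttribute::get(ctx, "relu")}};

auto fused = builder.Build<paddle::dialect::FusedGemmEpilogueOp>(
    x.out(), y.out(), bias.out(), attrs);
// fused.out() is typed as [64, 256]; fused.reserve_space() only gets a
// non-null type when activation != "none".
```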
Parent af6324aa
......@@ -52,8 +52,11 @@ void PaddleDialect::initialize() {
RegisterOps<paddle::dialect::AddNOp,
paddle::dialect::AddN_Op,
paddle::dialect::AddNWithKernelOp,
paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::FusedGemmEpilogueGradOp,
paddle::dialect::SplitGradOp,
paddle::dialect::IfOp>();
RegisterInterfaces<ParameterConvertInterface>();
}
......
......@@ -24,6 +24,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
......@@ -409,6 +410,442 @@ void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}
const char *FusedGemmEpilogueOp::attributes_name[3] = {
"trans_x", "trans_y", "activation"};
OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo(
"x", "paddle::dialect::DenseTensorType", false, false, false, false),
paddle::dialect::OpInputInfo(
"y", "paddle::dialect::DenseTensorType", false, false, false, false),
paddle::dialect::OpInputInfo("bias",
"paddle::dialect::DenseTensorType",
false,
false,
false,
false)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("trans_x", "ir::BoolAttribute", ""),
paddle::dialect::OpAttributeInfo("trans_y", "ir::BoolAttribute", ""),
paddle::dialect::OpAttributeInfo("activation", "ir::StrAttribute", "")};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
paddle::dialect::OpOutputInfo(
"out", "paddle::dialect::DenseTensorType", false, false),
paddle::dialect::OpOutputInfo(
"reserve_space", "paddle::dialect::DenseTensorType", true, false)};
paddle::dialect::OpRunTimeInfo run_time_info(
"FusedGemmEpilogueInferMeta",
{"x", "y", "bias", "trans_x", "trans_y", "activation"},
{""},
{""},
{""},
{},
{},
{});
return std::make_tuple(
inputs, attributes, outputs, run_time_info, "fused_gemm_epilogue");
}
void FusedGemmEpilogueOp::Build(ir::Builder &builder,
ir::OperationArgument &argument,
ir::OpResult x_,
ir::OpResult y_,
ir::OpResult bias_,
ir::AttributeMap attributes) {
bool trans_x = attributes.at("trans_x").dyn_cast<ir::BoolAttribute>().data();
bool trans_y = attributes.at("trans_y").dyn_cast<ir::BoolAttribute>().data();
std::string activation =
attributes.at("activation").dyn_cast<ir::StrAttribute>().AsString();
VLOG(4) << "Builder construction inputs";
std::vector<ir::OpResult> argument_inputs = {x_, y_, bias_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
VLOG(4) << "Builder construction attributes";
ir::Attribute attr_trans_x =
ir::BoolAttribute::get(ir::IrContext::Instance(), trans_x);
argument.AddAttribute("trans_x", attr_trans_x);
ir::Attribute attr_trans_y =
ir::BoolAttribute::get(ir::IrContext::Instance(), trans_y);
argument.AddAttribute("trans_y", attr_trans_y);
ir::Attribute attr_activation =
ir::StrAttribute::get(ir::IrContext::Instance(), activation);
argument.AddAttribute("activation", attr_activation);
VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorType x =
x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)x;
paddle::dialect::DenseTensorType y =
y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)y;
paddle::dialect::DenseTensorType bias =
bias_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)bias;
VLOG(4) << "Builder construction dense_x";
phi::DenseTensor dense_x(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(x.dtype()),
x.dims(),
x.data_layout(),
x.lod(),
x.offset()));
VLOG(4) << "Builder construction meta_x";
phi::MetaTensor meta_x(&dense_x);
VLOG(4) << "Builder construction dense_y";
phi::DenseTensor dense_y(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(y.dtype()),
y.dims(),
y.data_layout(),
y.lod(),
y.offset()));
VLOG(4) << "Builder construction meta_y";
phi::MetaTensor meta_y(&dense_y);
VLOG(4) << "Builder construction dense_bias";
phi::DenseTensor dense_bias(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(bias.dtype()),
bias.dims(),
bias.data_layout(),
bias.lod(),
bias.offset()));
VLOG(4) << "Builder construction meta_bias";
phi::MetaTensor meta_bias(&dense_bias);
phi::DenseTensor dense_out;
phi::MetaTensor meta_out(&dense_out);
phi::DenseTensor dense_reserve_space;
phi::MetaTensor meta_reserve_space(&dense_reserve_space);
phi::FusedGemmEpilogueInferMeta(
meta_x,
meta_y,
meta_bias,
trans_x,
trans_y,
activation,
&meta_out,
activation == "none" ? nullptr : &meta_reserve_space);
std::vector<ir::Type> argument_outputs;
ir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_out.dtype()),
dense_out.dims(),
dense_out.layout(),
dense_out.lod(),
dense_out.offset());
argument_outputs.push_back(out_dense_tensor_type);
ir::Type reserve_space_dense_tensor_type =
activation == "none"
? ir::Type()
: paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_reserve_space.dtype()),
dense_reserve_space.dims(),
dense_reserve_space.layout(),
dense_reserve_space.lod(),
dense_reserve_space.offset());
argument_outputs.push_back(reserve_space_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
void FusedGemmEpilogueOp::Verify() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"FusedGemmEpilogueOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
3u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 3.", input_size));
PADDLE_ENFORCE((*this)
->operand_source(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
PADDLE_ENFORCE((*this)
->operand_source(1)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input."));
PADDLE_ENFORCE((*this)
->operand_source(2)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 2th input."));
}
VLOG(4) << "Verifying attributes:";
{
auto &attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("trans_x") > 0 &&
attributes.at("trans_x").isa<ir::BoolAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: trans_x is not right."));
PADDLE_ENFORCE(attributes.count("trans_y") > 0 &&
attributes.at("trans_y").isa<ir::BoolAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: trans_y is not right."));
PADDLE_ENFORCE(attributes.count("activation") > 0 &&
attributes.at("activation").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: activation is not right."));
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
2u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 2.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
if (auto output_1_type = (*this)->result(1).type()) {
PADDLE_ENFORCE(output_1_type.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th output."));
}
}
VLOG(4) << "End Verifying for: FusedGemmEpilogueOp.";
}
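Note that `reserve_space` is an optional result: when `activation == "none"`, `Build` pushes a null `ir::Type()` for result 1, and `Verify` above only type-checks it when it is set. A consumer can probe the result type the same way; a minimal sketch (the helper name is hypothetical):

```cpp
// Sketch: result 1 (reserve_space) carries a null type when activation == "none".
bool HasReserveSpace(paddle::dialect::FusedGemmEpilogueOp op) {
  return static_cast<bool>(op.reserve_space().type());
}
```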
void FusedGemmEpilogueOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::FusedGemmEpilogueInferMeta);
fn(infer_meta);
}
const char *FusedGemmEpilogueGradOp::attributes_name[3] = {
"trans_x", "trans_y", "activation_grad"};
OpInfoTuple FusedGemmEpilogueGradOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo(
"x", "paddle::dialect::DenseTensorType", false, false, false, false),
paddle::dialect::OpInputInfo(
"y", "paddle::dialect::DenseTensorType", false, false, false, false),
paddle::dialect::OpInputInfo("reserve_space",
"paddle::dialect::DenseTensorType",
true,
false,
false,
false),
paddle::dialect::OpInputInfo("out_grad",
"paddle::dialect::DenseTensorType",
false,
false,
false,
false)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("trans_x", "ir::BoolAttribute", ""),
paddle::dialect::OpAttributeInfo("trans_y", "ir::BoolAttribute", ""),
paddle::dialect::OpAttributeInfo(
"activation_grad", "ir::StrAttribute", "")};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
paddle::dialect::OpOutputInfo(
"x_grad", "paddle::dialect::DenseTensorType", false, false),
paddle::dialect::OpOutputInfo(
"y_grad", "paddle::dialect::DenseTensorType", false, false),
paddle::dialect::OpOutputInfo(
"bias_grad", "paddle::dialect::DenseTensorType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info("FusedGemmEpilogueGradInferMeta",
{"x",
"y",
"reserve_space",
"out_grad",
"trans_x",
"trans_y",
"activation_grad"},
{""},
{""},
{""},
{},
{},
{});
return std::make_tuple(
inputs, attributes, outputs, run_time_info, "fused_gemm_epilogue_grad");
}
void FusedGemmEpilogueGradOp::Build(ir::Builder &builder,
ir::OperationArgument &argument,
ir::OpResult x_,
ir::OpResult y_,
ir::OpResult reserve_space_,
ir::OpResult out_grad_,
ir::AttributeMap attributes) {
bool trans_x = attributes.at("trans_x").dyn_cast<ir::BoolAttribute>().data();
bool trans_y = attributes.at("trans_y").dyn_cast<ir::BoolAttribute>().data();
std::string activation_grad =
attributes.at("activation_grad").dyn_cast<ir::StrAttribute>().AsString();
VLOG(4) << "Builder construction inputs";
std::vector<ir::OpResult> argument_inputs = {
x_, y_, reserve_space_, out_grad_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
VLOG(4) << "Builder construction attributes";
ir::Attribute attr_trans_x =
ir::BoolAttribute::get(ir::IrContext::Instance(), trans_x);
argument.AddAttribute("trans_x", attr_trans_x);
ir::Attribute attr_trans_y =
ir::BoolAttribute::get(ir::IrContext::Instance(), trans_y);
argument.AddAttribute("trans_y", attr_trans_y);
ir::Attribute attr_activation_grad =
ir::StrAttribute::get(ir::IrContext::Instance(), activation_grad);
argument.AddAttribute("activation_grad", attr_activation_grad);
VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorType x =
x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)x;
paddle::dialect::DenseTensorType y =
y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)y;
paddle::dialect::DenseTensorType reserve_space =
reserve_space_
? reserve_space_.type().dyn_cast<paddle::dialect::DenseTensorType>()
: paddle::dialect::DenseTensorType();
(void)reserve_space;
paddle::dialect::DenseTensorType out_grad =
out_grad_.type().dyn_cast<paddle::dialect::DenseTensorType>();
(void)out_grad;
VLOG(4) << "Builder construction dense_x";
phi::DenseTensor dense_x(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(x.dtype()),
x.dims(),
x.data_layout(),
x.lod(),
x.offset()));
VLOG(4) << "Builder construction meta_x";
phi::MetaTensor meta_x(&dense_x);
VLOG(4) << "Builder construction dense_y";
phi::DenseTensor dense_y(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(paddle::dialect::TransToPhiDataType(y.dtype()),
y.dims(),
y.data_layout(),
y.lod(),
y.offset()));
VLOG(4) << "Builder construction meta_y";
phi::MetaTensor meta_y(&dense_y);
VLOG(4) << "Builder construction dense_reserve_space";
std::unique_ptr<phi::DenseTensor> dense_reserve_space =
reserve_space_
? std::make_unique<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(
paddle::dialect::TransToPhiDataType(reserve_space.dtype()),
reserve_space.dims(),
reserve_space.data_layout(),
reserve_space.lod(),
reserve_space.offset()))
: nullptr;
VLOG(4) << "Builder construction meta_reserve_space";
phi::MetaTensor meta_reserve_space(dense_reserve_space.get());
VLOG(4) << "Builder construction dense_out_grad";
phi::DenseTensor dense_out_grad(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(
paddle::dialect::TransToPhiDataType(out_grad.dtype()),
out_grad.dims(),
out_grad.data_layout(),
out_grad.lod(),
out_grad.offset()));
VLOG(4) << "Builder construction meta_out_grad";
phi::MetaTensor meta_out_grad(&dense_out_grad);
phi::DenseTensor dense_x_grad;
phi::MetaTensor meta_x_grad(&dense_x_grad);
phi::DenseTensor dense_y_grad;
phi::MetaTensor meta_y_grad(&dense_y_grad);
phi::DenseTensor dense_bias_grad;
phi::MetaTensor meta_bias_grad(&dense_bias_grad);
phi::FusedGemmEpilogueGradInferMeta(meta_x,
meta_y,
meta_reserve_space,
meta_out_grad,
trans_x,
trans_y,
activation_grad,
&meta_x_grad,
&meta_y_grad,
&meta_bias_grad);
std::vector<ir::Type> argument_outputs;
ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_x_grad.dtype()),
dense_x_grad.dims(),
dense_x_grad.layout(),
dense_x_grad.lod(),
dense_x_grad.offset());
argument_outputs.push_back(x_grad_dense_tensor_type);
ir::Type y_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_y_grad.dtype()),
dense_y_grad.dims(),
dense_y_grad.layout(),
dense_y_grad.lod(),
dense_y_grad.offset());
argument_outputs.push_back(y_grad_dense_tensor_type);
ir::Type bias_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_bias_grad.dtype()),
dense_bias_grad.dims(),
dense_bias_grad.layout(),
dense_bias_grad.lod(),
dense_bias_grad.offset());
argument_outputs.push_back(bias_grad_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
void FusedGemmEpilogueGradOp::Verify() {}
void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta);
fn(infer_meta);
}
const char *SplitGradOp::attributes_name[1] = {"axis"};
OpInfoTuple SplitGradOp::GetOpInfo() {
......@@ -673,4 +1110,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
......@@ -94,6 +94,62 @@ class AddNWithKernelOp : public ir::Op<AddNWithKernelOp,
static void InferMeta(phi::InferMetaContext *infer_meta);
};
class FusedGemmEpilogueOp : public ir::Op<FusedGemmEpilogueOp,
paddle::dialect::OpYamlInfoInterface,
paddle::dialect::InferMetaInterface> {
public:
using Op::Op;
static const char *name() { return "pd.fused_gemm_epilogue"; }
static const char *attributes_name[3];
static constexpr uint32_t attributes_num = 3;
static OpInfoTuple GetOpInfo();
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult x_,
ir::OpResult y_,
ir::OpResult bias_,
ir::AttributeMap attributes);
void Verify();
ir::Value x() { return operand_source(0); }
ir::Value y() { return operand_source(1); }
ir::Value bias() { return operand_source(2); }
ir::OpResult out() { return result(0); }
ir::OpResult reserve_space() { return result(1); }
static void InferMeta(phi::InferMetaContext *infer_meta);
};
class FusedGemmEpilogueGradOp
: public ir::Op<FusedGemmEpilogueGradOp,
paddle::dialect::OpYamlInfoInterface,
paddle::dialect::InferMetaInterface> {
public:
using Op::Op;
static const char *name() { return "pd.fused_gemm_epilogue_grad"; }
static const char *attributes_name[3];
static constexpr uint32_t attributes_num = 3;
static OpInfoTuple GetOpInfo();
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult x_,
ir::OpResult y_,
ir::OpResult reserve_space_,
ir::OpResult out_grad_,
ir::AttributeMap attributes);
void Verify();
ir::Value x() { return operand_source(0); }
ir::Value y() { return operand_source(1); }
ir::Value reserve_space() { return operand_source(2); }
ir::Value out_grad() { return operand_source(3); }
ir::OpResult x_grad() { return result(0); }
ir::OpResult y_grad() { return result(1); }
ir::OpResult bias_grad() { return result(2); }
static void InferMeta(phi::InferMetaContext *infer_meta);
};
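Continuing the earlier sketch, the grad op can be built through the same interface. The `Build` signature and accessors come from the declarations above; passing a default-constructed `ir::OpResult` for the optional `reserve_space_` input (every use of it in `Build` is guarded with `reserve_space_ ? ... : ...`) and the existence of an `out_grad` value are assumptions made only for illustration.

```cpp
// Sketch only: pd.fused_gemm_epilogue_grad, reusing ctx/builder/x/y from the
// forward sketch; out_grad is assumed to be an ir::OpResult produced elsewhere.
ir::AttributeMap grad_attrs = {
    {"trans_x", ir::BoolAttribute::get(ctx, false)},
    {"trans_y", ir::BoolAttribute::get(ctx, false)},
    {"activation_grad", ir::StrAttribute::get(ctx, "none")}};

auto grad = builder.Build<paddle::dialect::FusedGemmEpilogueGradOp>(
    x.out(), y.out(), /*reserve_space_=*/ir::OpResult(), out_grad, grad_attrs);
// grad.x_grad(), grad.y_grad() and grad.bias_grad() mirror the shapes of
// x, y and bias respectively.
```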
class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
public:
using Op::Op;
......@@ -141,5 +197,8 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
#endif
......@@ -70,6 +70,7 @@ class IR_API Block {
bool HasOneUse() const;
BlockOperand *first_use_addr() { return &first_use_; }
// This is an unsafe function; please use it carefully.
void ResetOpListOrder(const OpListType &new_op_list);
private:
......
......@@ -22,9 +22,6 @@
namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
class ReorderBlockOpsPass : public ir::Pass {
public:
ReorderBlockOpsPass() : ir::Pass("ReorderBlockOpsPass", 0) {}
......
......@@ -446,6 +446,190 @@ void MultiEncoderXPUInferMeta(
}
}
void FusedGemmEpilogueInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& bias,
bool trans_x,
bool trans_y,
const std::string& activation,
MetaTensor* out,
MetaTensor* reserve_space) {
const auto& x_dims = x.dims();
const auto& y_dims = y.dims();
const auto& bias_dims = bias.dims();
PADDLE_ENFORCE_EQ(y_dims.size(),
2,
phi::errors::InvalidArgument(
"The Input tensor Y's dimension of FusedGemmEpilogueOp "
" should be 2, but got %d.",
y_dims.size()));
PADDLE_ENFORCE_GE(x_dims.size(),
2,
phi::errors::InvalidArgument(
"The Input tensor X's dimension of FusedGemmEpilogueOp "
" should be >= 2, but got %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
bias_dims.size(),
1,
phi::errors::InvalidArgument(
"The Input tensor bias's dimension of FusedGemmEpilogueOp "
" should be == 1, but got %d.",
bias_dims.size()));
PADDLE_ENFORCE_EQ(bias_dims[0],
trans_y ? y_dims[0] : y_dims[1],
phi::errors::InvalidArgument(
"The Input tensor bias's dimension 0"
" should be == Y[-1], but got bias's shape = [%s] "
"and Y's shape = [%s]",
bias_dims,
y_dims));
auto x_mat_dims = phi::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1);
int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1];
int K_from_y = trans_y ? y_dims[1] : y_dims[0];
PADDLE_ENFORCE_EQ(
K_from_x,
K_from_y,
phi::errors::InvalidArgument(
"The last dimension of X should be equal with Y's first dimension."
"But received X[-1] = [%d], Y[0] = [%d].",
K_from_x,
K_from_y));
std::vector<int64_t> out_dims;
out_dims.reserve(static_cast<size_t>(x_dims.size()));
if (trans_x) {
for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]);
} else {
for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]);
}
if (trans_y) {
out_dims.push_back(y_dims[0]);
} else {
out_dims.push_back(y_dims[1]);
}
out->set_dims(phi::make_ddim(out_dims));
out->set_dtype(x.dtype());
if (reserve_space) {
reserve_space->set_dims(phi::make_ddim(out_dims));
reserve_space->set_dtype(x.dtype());
if (activation == "none") {
PADDLE_THROW(phi::errors::InvalidArgument(
"The ReserveSpace would not be used when activation = \"none\""));
} else {
int min_size_of_n = activation == "relu" ? 128 : 8;
int N_size = trans_y ? y_dims[0] : y_dims[1];
PADDLE_ENFORCE_EQ(N_size % min_size_of_n,
0,
phi::errors::InvalidArgument(
"The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
"should be multiple of %d when auxiliary_key given "
"and activation=%s, but got N = %d.",
min_size_of_n,
activation,
N_size));
}
}
}
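The shape rule enforced above can be restated compactly: flatten X to 2-D, keep its leading dimensions as the output row dims, take N from Y, and require the shared K dimension to match. A standalone illustrative sketch (not part of the patch):

```cpp
// Illustrative re-statement of the output-shape rule of FusedGemmEpilogueInferMeta.
// e.g. X = [64, 128], Y = [128, 256], bias = [256]  ->  out = [64, 256];
// with activation == "relu", N = 256 must also be a multiple of 128 for
// reserve_space to be valid.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> FusedGemmEpilogueOutShape(const std::vector<int64_t>& x,
                                               const std::vector<int64_t>& y,
                                               bool trans_x,
                                               bool trans_y) {
  const int64_t k_from_x = trans_x ? x.front() : x.back();
  const int64_t k_from_y = trans_y ? y[1] : y[0];
  assert(k_from_x == k_from_y && "K dimension of X and Y must match");
  std::vector<int64_t> out = trans_x
      ? std::vector<int64_t>(x.begin() + 1, x.end())
      : std::vector<int64_t>(x.begin(), x.end() - 1);
  out.push_back(trans_y ? y[0] : y[1]);
  return out;
}
```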
void FusedGemmEpilogueGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& reserve_space,
const MetaTensor& out_grad,
bool trans_x,
bool trans_y,
const std::string& activation_grad,
MetaTensor* x_grad,
MetaTensor* y_grad,
MetaTensor* bias_grad) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto dout_dims = out_grad.dims();
PADDLE_ENFORCE_GE(
dout_dims.size(),
2,
phi::errors::InvalidArgument(
"The Input tensor DOut's dimension of FusedGemmEpilogueGradOp "
" should be >= 2, but got %d.",
dout_dims.size()));
PADDLE_ENFORCE_EQ(
y_dims.size(),
2,
phi::errors::InvalidArgument(
"The Input tensor Y's dimension of FusedGemmEpilogueGradOp "
" should be 2, but got %d.",
y_dims.size()));
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument(
"The Input tensor X's dimension of FusedGemmEpilogueGradOp "
" should be >= 2, but got %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
dout_dims.size(),
x_dims.size(),
phi::errors::InvalidArgument(
"The Input tensor DOut's and X's dimension of "
"FusedGemmEpilogueGradOp "
" should be the same, but got DOut's dim = %d and X's = %d.",
dout_dims.size(),
x_dims.size()));
auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1);
auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
PADDLE_ENFORCE_EQ(
dout_mat_dims[1],
trans_y ? y_dims[0] : y_dims[1],
phi::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last"
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1],
y_dims[1]));
PADDLE_ENFORCE_EQ(
dout_mat_dims[0],
trans_x ? x_mat_dims[1] : x_mat_dims[0],
phi::errors::InvalidArgument(
"The first dimension of DOut should be equal with X's first"
"dimension. But received DOut[0] = [%d], Y[0] = [%d].",
dout_mat_dims[0],
x_mat_dims[0]));
if (activation_grad != "none" && !reserve_space) {
PADDLE_THROW(phi::errors::InvalidArgument(
"The ReserveSpace should not be empty. "
"when activation == {relu_grad, gelu_grad}."));
}
if (x_grad) {
x_grad->set_dims(x_dims);
x_grad->set_dtype(x.dtype());
}
y_grad->set_dims(y_dims);
y_grad->set_dtype(y.dtype());
if (bias_grad) {
int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
bias_grad->set_dims(phi::make_ddim({dbias_dim}));
bias_grad->set_dtype(y.dtype());
}
}
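As a concrete check of the gradient shapes above: with X = [64, 128], Y = [128, 256] and DOut = [64, 256] (trans_x = trans_y = false), the flattened DOut is 64 x 256, its last dimension matches N = 256 and its first matches M = 64, so x_grad = [64, 128], y_grad = [128, 256] and bias_grad = [256].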
void FusedMultiTransformerXpuInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scale,
......
......@@ -123,6 +123,26 @@ void MultiEncoderXPUInferMeta(
MetaTensor* x_fp16,
MetaTensor* out_fp16);
void FusedGemmEpilogueInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& bias,
bool trans_x,
bool trans_y,
const std::string& activation,
MetaTensor* out,
MetaTensor* reserve_space);
void FusedGemmEpilogueGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& reserve_space,
const MetaTensor& out_grad,
bool trans_x,
bool trans_y,
const std::string& activation_grad,
MetaTensor* x_grad,
MetaTensor* y_grad,
MetaTensor* bias_grad);
void FusedMultiTransformerXpuInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scale,
......