Unverified · Commit aede713a authored by HongyuJia, committed by GitHub

[MKLDNN] Delete mkldnn hard code of mul (#47166)

* delete GetExpectedKernelType mkldnn of mul_grad

* update mkldnn_op_list, remove mul_grad

* delete GetExpectedKernelType mkldnn of mul
Parent 9f666615
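
What the change amounts to, condensed from the diff below (a sketch for orientation, not a drop-in snippet; it omits headers and surrounding namespace scope): the per-data-type customized_type_value plumbing for mul is removed, so the oneDNN kernels are selected through the framework's regular dispatch instead of hand-written hard code.

// Before: one REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE entry per data type (U8,
// S8, FP32, BF16), keyed by kMULMKLDNNINT8 / kMULMKLDNNFP32 and matched by a
// hand-written MKLDNN branch in MulOp::GetExpectedKernelType, e.g.:
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    FP32,
                                    ops::kMULMKLDNNFP32,
                                    ops::MulMKLDNNKernel<float, float>);

// After: a single plain registration lists every oneDNN kernel with one
// data-type template parameter, GetExpectedKernelType simply returns the
// default kernel type, and the framework resolves the kernel from the input
// data type.
REGISTER_OP_KERNEL(mul,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::MulMKLDNNINT8Kernel<uint8_t>,
                   ops::MulMKLDNNINT8Kernel<int8_t>,
                   ops::MulMKLDNNKernel<paddle::platform::bfloat16>,
                   ops::MulMKLDNNKernel<float>);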
@@ -37,9 +37,6 @@ using dnnl::memory;
 using dnnl::prop_kind;
 using dnnl::stream;
-constexpr int kMULMKLDNNINT8 = 1;
-constexpr int kMULMKLDNNFP32 = 2;
 template <typename XT, typename YT, typename OT>
 class MulPrimitiveFactory {
 public:
@@ -340,6 +337,7 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
 return prim_creator;
 }
+/* XT: input x data type, YT: input y data type */
 template <typename XT, typename YT>
 inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
 const ExecutionContext &ctx,
@@ -363,8 +361,8 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
 }
 }
-/* XT: input x data type, YT: input y data type */
-template <typename XT, typename YT>
+/* XT: input x data type */
+template <typename XT>
 class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
 public:
 void Compute(const ExecutionContext &ctx) const override {
@@ -381,7 +379,8 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
 Tensor *out = ctx.Output<phi::DenseTensor>("Out");
 auto out_dims = out->dims();
-auto mul = GetMulPrimitive<XT, YT>(dev_ctx, ctx, x, y, out, mkldnn_engine);
+auto mul =
+    GetMulPrimitive<XT, float>(dev_ctx, ctx, x, y, out, mkldnn_engine);
 if (out_dims.size() != 2) {
 out->Resize(out_dims);
@@ -393,7 +392,7 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
 }
 };
-template <typename XT, typename YT>
+template <typename XT>
 class MulMKLDNNKernel : public framework::OpKernel<XT> {
 public:
 void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
@@ -411,7 +410,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
 bool trans_y,
 Tensor *out) const {
 static const std::vector<int64_t> vec_placeholder;
-MatMulV2MKLDNNHandler<XT, YT, XT> handler(ctx,
+MatMulV2MKLDNNHandler<XT, XT, XT> handler(ctx,
 onednn_engine,
 ctx.GetPlace(),
 x_dims,
@@ -487,13 +486,12 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
 }
 };
-template <typename XT, typename YT>
-class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> {
+template <typename XT>
+class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT> {
 public:
 void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
 private:
-template <typename OT = XT>
 void RunKernel(const ExecutionContext &ctx) const {
 const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
 const auto &onednn_engine = dev_ctx.GetEngine();
@@ -569,57 +567,17 @@ class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> {
 } // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
-MKLDNN,
-::paddle::platform::CPUPlace,
-U8,
-ops::kMULMKLDNNINT8,
-ops::MulMKLDNNINT8Kernel<uint8_t, float>);
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
-MKLDNN,
-::paddle::platform::CPUPlace,
-S8,
-ops::kMULMKLDNNINT8,
-ops::MulMKLDNNINT8Kernel<int8_t, float>);
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
-MKLDNN,
-::paddle::platform::CPUPlace,
-FP32,
-ops::kMULMKLDNNFP32,
-ops::MulMKLDNNKernel<float, float>);
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-mul,
-MKLDNN,
-::paddle::platform::CPUPlace,
-BF16,
-ops::kMULMKLDNNFP32,
-ops::MulMKLDNNKernel<paddle::platform::bfloat16,
-paddle::platform::bfloat16>);
 REGISTER_OP_KERNEL(mul,
 MKLDNN,
 ::paddle::platform::CPUPlace,
-ops::MulMKLDNNINT8Kernel<uint8_t, float>,
-ops::MulMKLDNNKernel<paddle::platform::bfloat16,
-paddle::platform::bfloat16>,
-ops::MulMKLDNNKernel<float, float>);
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul_grad,
-MKLDNN,
-::paddle::platform::CPUPlace,
-FP32,
-ops::kMULMKLDNNFP32,
-ops::MulGradMKLDNNKernel<float, float>);
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-mul_grad,
-MKLDNN,
-::paddle::platform::CPUPlace,
-BF16,
-ops::kMULMKLDNNFP32,
-ops::MulGradMKLDNNKernel<paddle::platform::bfloat16,
-paddle::platform::bfloat16>,
-ops::MulGradMKLDNNKernel<float, float>);
+ops::MulMKLDNNINT8Kernel<uint8_t>,
+ops::MulMKLDNNINT8Kernel<int8_t>,
+ops::MulMKLDNNKernel<paddle::platform::bfloat16>,
+ops::MulMKLDNNKernel<float>);
+REGISTER_OP_KERNEL(mul_grad,
+MKLDNN,
+::paddle::platform::CPUPlace,
+ops::MulGradMKLDNNKernel<paddle::platform::bfloat16>,
+ops::MulGradMKLDNNKernel<float>);
@@ -31,9 +31,6 @@ namespace operators {
 using framework::OpKernelType;
-constexpr int kMULMKLDNNINT8 = 1;
-constexpr int kMULMKLDNNFP32 = 2;
 class MulOp : public framework::OperatorWithKernel {
 public:
 using framework::OperatorWithKernel::OperatorWithKernel;
@@ -41,29 +38,6 @@ class MulOp : public framework::OperatorWithKernel {
 framework::OpKernelType GetExpectedKernelType(
 const framework::ExecutionContext& ctx) const {
 auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_MKLDNN
-if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-int customized_type_value =
-framework::OpKernelType::kDefaultCustomizedTypeValue;
-if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
-input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
-customized_type_value = kMULMKLDNNINT8;
-} else if (input_data_type ==
-framework::DataTypeTrait<
-paddle::platform::bfloat16>::DataType() ||
-input_data_type ==
-framework::DataTypeTrait<float>::DataType()) {
-customized_type_value = kMULMKLDNNFP32;
-}
-return framework::OpKernelType(input_data_type,
-ctx.GetPlace(),
-phi::DataLayout::kMKLDNN,
-framework::LibraryType::kMKLDNN,
-customized_type_value);
-}
-#endif
 return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }
 };
@@ -136,29 +110,6 @@ class MulGradOp : public framework::OperatorWithKernel {
 framework::OpKernelType GetExpectedKernelType(
 const framework::ExecutionContext& ctx) const {
 auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_MKLDNN
-if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-int customized_type_value =
-framework::OpKernelType::kDefaultCustomizedTypeValue;
-if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
-input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
-customized_type_value = kMULMKLDNNINT8;
-} else if (input_data_type ==
-framework::DataTypeTrait<
-paddle::platform::bfloat16>::DataType() ||
-input_data_type ==
-framework::DataTypeTrait<float>::DataType()) {
-customized_type_value = kMULMKLDNNFP32;
-}
-return framework::OpKernelType(input_data_type,
-ctx.GetPlace(),
-phi::DataLayout::kMKLDNN,
-framework::LibraryType::kMKLDNN,
-customized_type_value);
-}
-#endif
 return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }
 };
@@ -69,9 +69,7 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
 "reduce_sum_grad",
 // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
 // need to analysis and solve them one-by-one.
-"prior_box",
-"mul",
-"mul_grad"};
+"prior_box"};
 inline bool in_mkldnn_white_list(const std::string& op_name) {
 return mkldnn_white_list.find(op_name) != mkldnn_white_list.end();
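
For the white list above, the practical effect is that only prior_box still needs the customized_type_value escape hatch. A hypothetical sanity check (not part of the patch, assuming the declarations above are visible) would look like:

#include <cassert>
#include <string>

// Hypothetical check, not in the patch: after this commit mul and mul_grad
// are served by the generic oneDNN dispatch and are no longer white-listed.
void CheckMulRemovedFromWhiteList() {
  assert(in_mkldnn_white_list("prior_box"));  // still special-cased above
  assert(!in_mkldnn_white_list("mul"));       // removed by this commit
  assert(!in_mkldnn_white_list("mul_grad"));  // removed by this commit
}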