未验证 提交 aede713a 编写于 作者: H HongyuJia 提交者: GitHub

[MKLDNN] Delete mkldnn hard code of mul (#47166)

* Delete the MKLDNN-specific branch of GetExpectedKernelType for mul_grad

* update mkldnn_op_list, remove mul_grad

* Delete the MKLDNN-specific branch of GetExpectedKernelType for mul
上级 9f666615
...@@ -37,9 +37,6 @@ using dnnl::memory; ...@@ -37,9 +37,6 @@ using dnnl::memory;
using dnnl::prop_kind; using dnnl::prop_kind;
using dnnl::stream; using dnnl::stream;
constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;
template <typename XT, typename YT, typename OT> template <typename XT, typename YT, typename OT>
class MulPrimitiveFactory { class MulPrimitiveFactory {
public: public:
...@@ -340,6 +337,7 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -340,6 +337,7 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
return prim_creator; return prim_creator;
} }
/* XT: input x data type, YT: input y data type */
template <typename XT, typename YT> template <typename XT, typename YT>
inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx, inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
const ExecutionContext &ctx, const ExecutionContext &ctx,
...@@ -363,8 +361,8 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx, ...@@ -363,8 +361,8 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
} }
} }
/* XT: input x data type, YT: input y data type */ /* XT: input x data type */
template <typename XT, typename YT> template <typename XT>
class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> { class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
public: public:
void Compute(const ExecutionContext &ctx) const override { void Compute(const ExecutionContext &ctx) const override {
...@@ -381,7 +379,8 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> { ...@@ -381,7 +379,8 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
Tensor *out = ctx.Output<phi::DenseTensor>("Out"); Tensor *out = ctx.Output<phi::DenseTensor>("Out");
auto out_dims = out->dims(); auto out_dims = out->dims();
auto mul = GetMulPrimitive<XT, YT>(dev_ctx, ctx, x, y, out, mkldnn_engine); auto mul =
GetMulPrimitive<XT, float>(dev_ctx, ctx, x, y, out, mkldnn_engine);
if (out_dims.size() != 2) { if (out_dims.size() != 2) {
out->Resize(out_dims); out->Resize(out_dims);
...@@ -393,7 +392,7 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> { ...@@ -393,7 +392,7 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
} }
}; };
template <typename XT, typename YT> template <typename XT>
class MulMKLDNNKernel : public framework::OpKernel<XT> { class MulMKLDNNKernel : public framework::OpKernel<XT> {
public: public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
...@@ -411,7 +410,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -411,7 +410,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
bool trans_y, bool trans_y,
Tensor *out) const { Tensor *out) const {
static const std::vector<int64_t> vec_placeholder; static const std::vector<int64_t> vec_placeholder;
MatMulV2MKLDNNHandler<XT, YT, XT> handler(ctx, MatMulV2MKLDNNHandler<XT, XT, XT> handler(ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x_dims, x_dims,
...@@ -487,13 +486,12 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -487,13 +486,12 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
} }
}; };
template <typename XT, typename YT> template <typename XT>
class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> { class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT> {
public: public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
private: private:
template <typename OT = XT>
void RunKernel(const ExecutionContext &ctx) const { void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine(); const auto &onednn_engine = dev_ctx.GetEngine();
...@@ -569,57 +567,17 @@ class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> { ...@@ -569,57 +567,17 @@ class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kMULMKLDNNINT8,
ops::MulMKLDNNINT8Kernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kMULMKLDNNINT8,
ops::MulMKLDNNINT8Kernel<int8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kMULMKLDNNFP32,
ops::MulMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
mul,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kMULMKLDNNFP32,
ops::MulMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(mul, REGISTER_OP_KERNEL(mul,
MKLDNN, MKLDNN,
::paddle::platform::CPUPlace, ::paddle::platform::CPUPlace,
ops::MulMKLDNNINT8Kernel<uint8_t, float>, ops::MulMKLDNNINT8Kernel<uint8_t>,
ops::MulMKLDNNKernel<paddle::platform::bfloat16, ops::MulMKLDNNINT8Kernel<int8_t>,
paddle::platform::bfloat16>, ops::MulMKLDNNKernel<paddle::platform::bfloat16>,
ops::MulMKLDNNKernel<float, float>); ops::MulMKLDNNKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul_grad,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kMULMKLDNNFP32,
ops::MulGradMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( REGISTER_OP_KERNEL(mul_grad,
mul_grad,
MKLDNN, MKLDNN,
::paddle::platform::CPUPlace, ::paddle::platform::CPUPlace,
BF16, ops::MulGradMKLDNNKernel<paddle::platform::bfloat16>,
ops::kMULMKLDNNFP32, ops::MulGradMKLDNNKernel<float>);
ops::MulGradMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>,
ops::MulGradMKLDNNKernel<float, float>);
...@@ -31,9 +31,6 @@ namespace operators { ...@@ -31,9 +31,6 @@ namespace operators {
using framework::OpKernelType; using framework::OpKernelType;
constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;
class MulOp : public framework::OperatorWithKernel { class MulOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -41,29 +38,6 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -41,29 +38,6 @@ class MulOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
} else if (input_data_type ==
framework::DataTypeTrait<
paddle::platform::bfloat16>::DataType() ||
input_data_type ==
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -136,29 +110,6 @@ class MulGradOp : public framework::OperatorWithKernel { ...@@ -136,29 +110,6 @@ class MulGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
} else if (input_data_type ==
framework::DataTypeTrait<
paddle::platform::bfloat16>::DataType() ||
input_data_type ==
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -69,9 +69,7 @@ static const std::unordered_set<std::string> mkldnn_white_list = { ...@@ -69,9 +69,7 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
"reduce_sum_grad", "reduce_sum_grad",
// NOTE(jiahongyu): The ops below register kernels with customized_type_value; we // NOTE(jiahongyu): The ops below register kernels with customized_type_value; we
// need to analyze and resolve them one by one. // need to analyze and resolve them one by one.
"prior_box", "prior_box"};
"mul",
"mul_grad"};
inline bool in_mkldnn_white_list(const std::string& op_name) { inline bool in_mkldnn_white_list(const std::string& op_name) {
return mkldnn_white_list.find(op_name) != mkldnn_white_list.end(); return mkldnn_white_list.find(op_name) != mkldnn_white_list.end();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册