Unverified commit eded6013, authored by Chen Weihang, committed by GitHub

Simplify conv_mkldnn op registration (#46907)

* simplify conv_mkldnn op registration

* remove custom type value in conv grad op
Parent 2010bdc3
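In essence, the per-type `REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE` registrations (one per input/filter data-type combination, each keyed by a customized type value) are collapsed into a single `REGISTER_OP_KERNEL` call that lists all kernel instantiations, and the filter-type choice moves to run time inside the kernel via the `PD_VISIT_FLOAT_AND_INT8_TYPES` / `PD_VISIT_FLOAT_AND_BF16_TYPES` visitor macros introduced below. A condensed before/after sketch of the registration pattern, abridged from the conv2d hunks in this diff (it assumes the surrounding headers and the `ops` namespace alias, and is not a standalone translation unit):

// Before: one registration per data-type combination, keyed by a custom type
// value, with the kernel templated on both the input type T and the filter type K.
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    S8,
                                    ops::kConvMKLDNNINT8,
                                    ops::ConvMKLDNNOpKernel<int8_t, float>);

// After: a single registration listing every instantiation; the kernel keeps
// only the input type T and dispatches on filter->dtype() at run time.
REGISTER_OP_KERNEL(conv2d,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::ConvMKLDNNOpKernel<float>,
                   ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
                   ops::ConvMKLDNNOpKernel<uint8_t>,
                   ops::ConvMKLDNNOpKernel<int8_t>);

Because the kernel no longer carries a custom type value, GetExpectedKernelType for the conv op and its grad op no longer needs to compute or pass customized_type_value, which is what the hunks below delete.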
@@ -511,13 +511,6 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
......
@@ -229,31 +229,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
(input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
framework::DataTypeTrait<int8_t>::DataType()
? kConvMKLDNNINT8WS8
: kConvMKLDNNINT8
: kConvMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
framework::LibraryType::kMKLDNN);
}
#endif
// #ifndef PADDLE_WITH_ASCEND_CL
// if (input_data_type == framework::proto::VarType::FP16) {
// PADDLE_ENFORCE_EQ(
// library, framework::LibraryType::kCUDNN,
// platform::errors::InvalidArgument(
// "float16 can only be used when CUDNN or NPU is used"));
// }
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
@@ -517,8 +499,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
kConvMKLDNNFP32);
framework::LibraryType::kMKLDNN);
}
#endif
......
@@ -20,6 +20,8 @@
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/core/visit_type.h"
namespace paddle {
namespace operators {
namespace {
@@ -774,7 +776,22 @@ class ConvMKLDNNHandlerT
} // anonymous namespace
template <typename T, typename K>
#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
template <typename T>
class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -828,7 +845,9 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
auto* output = ctx.Output<phi::DenseTensor>("Output");
ConvMKLDNNHandlerT<T, K, T_out> handler(
PD_VISIT_FLOAT_AND_INT8_TYPES(
filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
ConvMKLDNNHandlerT<T, data_t, T_out> handler(
ctx,
dev_ctx,
mkldnn_engine,
@@ -861,7 +880,8 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
auto bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, is_test);
args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
@@ -870,6 +890,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
astream.wait();
output->set_mem_desc(dst_memory_p->get_desc());
}));
}
template <typename T_out>
@@ -905,7 +926,9 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
auto* output = ctx.Output<phi::DenseTensor>("Output");
ConvMKLDNNHandlerT<T, K, T_out> handler(
PD_VISIT_FLOAT_AND_INT8_TYPES(
filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
ConvMKLDNNHandlerT<T, data_t, T_out> handler(
ctx,
dev_ctx,
mkldnn_engine,
@@ -923,7 +946,8 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
const bool is_multi_channel = scale_weights_data.size() > 1;
const int& groups = ctx.Attr<int>("groups");
int mask_reorder =
is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
: 0;
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, false, true, scale_weights_data, mask_reorder);
@@ -959,17 +983,19 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
std::vector<float> bias_scales;
auto p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
if (ctx.HasAttr("Bias_scales")) {
bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
} else {
p_scales_tuple = handler.get_int8_bias_scales(ctx);
}
auto bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias,
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias,
true,
std::get<1>(*p_scales_tuple),
std::get<0>(*p_scales_tuple));
@@ -985,10 +1011,28 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
}
output->set_mem_desc(dst_memory_p->get_desc());
}));
}
};
template <typename T, typename K>
#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::BFLOAT16, \
::phi::dtype::bfloat16, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
template <typename T>
class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -1013,8 +1057,10 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
if (!input_grad && !filter_grad) return;
PD_VISIT_FLOAT_AND_BF16_TYPES(
filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
// TODO(jczaja): Are all tensors really needed?
ConvMKLDNNHandlerT<T, K, T> handler(
ConvMKLDNNHandlerT<T, data_t, T> handler(
ctx,
dev_ctx,
ctx.GetPlace(),
@@ -1053,16 +1099,17 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
astream.wait();
// For convolution with groups convert from blocked to NCHW
// otherwise there will be problems in next operators working on this data
// otherwise there will be problems in next operators working on
// this data
if (g > 1) {
// in OneDNN groups in convolution are treated as separate dimension
// which is not the case in paddlepaddle
// in OneDNN groups in convolution are treated as separate
// dimension which is not the case in paddlepaddle
dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(filter->dtype()));
// for 3d conv with groups (six dimensional data reorder to goidhw)
// for 2d conv with groups (five dimensional data reorder to goihw)
// auto weights_tz = phi::vectorize(filter->dims());
// for 3d conv with groups (six dimensional data reorder to
// goidhw) for 2d conv with groups (five dimensional data reorder
// to goihw) auto weights_tz = phi::vectorize(filter->dims());
auto weights_tz = diff_weights_memory_p->get_desc().dims();
dnnl::memory::format_tag out_format =
@@ -1073,11 +1120,11 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
framework::TransToProtoVarType(filter->dtype()),
in_type,
mkldnn_engine);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());
auto reorder_dst_memory_p = handler.AcquireDstMemory(
filter_grad, out_format, ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p);
auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p,
diff_weights_memory_p);
{
platform::RecordEvent record_reorder(
@@ -1090,14 +1137,14 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
astream.wait();
}
// So here we have a data in goihw , which can be interpreted as OIHW
// (OIDHW for conv3d)
// because filter_grad shape is set for OIHW (OIDHW for conv3d)
// So here we have a data in goihw , which can be interpreted as
// OIHW (OIDHW for conv3d) because filter_grad shape is set for
// OIHW (OIDHW for conv3d)
dnnl::memory::format_tag target_format =
weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
: dnnl::memory::format_tag::oihw;
filter_grad->set_mem_desc(
dnnl::memory::desc(phi::vectorize<int64_t>(filter_grad->dims()),
filter_grad->set_mem_desc(dnnl::memory::desc(
phi::vectorize<int64_t>(filter_grad->dims()),
in_type,
target_format));
} else {
@@ -1126,6 +1173,7 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
input_grad->set_mem_desc(diff_src_memory_p->get_desc());
}
}));
}
};
@@ -1134,119 +1182,40 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
U8WS8,
ops::kConvMKLDNNINT8WS8,
ops::ConvMKLDNNOpKernel<uint8_t, int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<int8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
S8WS8,
ops::kConvMKLDNNINT8WS8,
ops::ConvMKLDNNOpKernel<int8_t, int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
conv2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
depthwise_conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
REGISTER_OP_KERNEL(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<uint8_t, float>);
ops::ConvMKLDNNOpKernel<float>,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
ops::ConvMKLDNNOpKernel<uint8_t>,
ops::ConvMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
REGISTER_OP_KERNEL(conv2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<int8_t, float>);
ops::ConvMKLDNNGradOpKernel<float>,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d_grad,
REGISTER_OP_KERNEL(depthwise_conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float, float>);
ops::ConvMKLDNNOpKernel<float>,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
ops::ConvMKLDNNOpKernel<uint8_t>,
ops::ConvMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
depthwise_conv2d_grad,
REGISTER_OP_KERNEL(depthwise_conv2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);
ops::ConvMKLDNNGradOpKernel<float>,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d,
REGISTER_OP_KERNEL(conv3d,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float, float>);
ops::ConvMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad,
REGISTER_OP_KERNEL(conv3d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float, float>);
ops::ConvMKLDNNGradOpKernel<float>);
@@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP_ITSELF(conv2d);
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
USE_OP_DEVICE_KERNEL(conv2d, MKLDNN);
namespace paddle {
namespace operators {
......