Unverified commit eded6013, authored by Chen Weihang, committed by GitHub

Simplify conv_mkldnn op registration (#46907)

* simplify conv_mkldnn op registration

* remove custom type value in conv grad op
Parent commit: 2010bdc3
@@ -511,13 +511,6 @@ function(op_library TARGET)
       # Append first implemented MKLDNN activation operator
       if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
         file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
-      elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
-        file(APPEND ${pybind_file}
-             "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
-        file(APPEND ${pybind_file}
-             "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
-        file(APPEND ${pybind_file}
-             "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
       elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
         file(APPEND ${pybind_file}
              "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
......
@@ -229,31 +229,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
 #ifdef PADDLE_WITH_MKLDNN
   if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-    int customized_type_value =
-        (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
-         input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
-            ? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
-                      framework::DataTypeTrait<int8_t>::DataType()
-                  ? kConvMKLDNNINT8WS8
-                  : kConvMKLDNNINT8
-            : kConvMKLDNNFP32;
     return framework::OpKernelType(input_data_type,
                                    ctx.GetPlace(),
                                    framework::DataLayout::kMKLDNN,
-                                   framework::LibraryType::kMKLDNN,
-                                   customized_type_value);
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
-  // #ifndef PADDLE_WITH_ASCEND_CL
-  //   if (input_data_type == framework::proto::VarType::FP16) {
-  //     PADDLE_ENFORCE_EQ(
-  //         library, framework::LibraryType::kCUDNN,
-  //         platform::errors::InvalidArgument(
-  //             "float16 can only be used when CUDNN or NPU is used"));
-  //   }
-  // #endif
   return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }

@@ -517,8 +499,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     return framework::OpKernelType(data_type,
                                    ctx.GetPlace(),
                                    framework::DataLayout::kMKLDNN,
-                                   framework::LibraryType::kMKLDNN,
-                                   kConvMKLDNNFP32);
+                                   framework::LibraryType::kMKLDNN);
  }
 #endif
......
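Context for the deletion above: before this commit, ConvOp::GetExpectedKernelType packed the filter data type into an extra kernel-key field through the nested ternary removed here. A minimal sketch of that selection logic, rewritten as a standalone helper for readability. The DType enum, its values, the helper name, and the main() checks are illustrative stand-ins; only the kConvMKLDNN* names come from the original conv_op code:

#include <cassert>

// Illustrative restatement of the removed customized_type_value ternary.
enum class DType { kFloat32, kBFloat16, kInt8, kUInt8 };
enum CustomType { kConvMKLDNNFP32, kConvMKLDNNINT8, kConvMKLDNNINT8WS8 };

CustomType SelectCustomizedTypeValue(DType input, DType filter) {
  const bool int8_input = (input == DType::kInt8 || input == DType::kUInt8);
  if (!int8_input) {
    return kConvMKLDNNFP32;  // fp32 / bf16 convolution
  }
  // int8 convolution: the value additionally encodes whether the weights
  // ("Filter") are signed int8.
  return filter == DType::kInt8 ? kConvMKLDNNINT8WS8 : kConvMKLDNNINT8;
}

int main() {
  assert(SelectCustomizedTypeValue(DType::kFloat32, DType::kFloat32) ==
         kConvMKLDNNFP32);
  assert(SelectCustomizedTypeValue(DType::kUInt8, DType::kInt8) ==
         kConvMKLDNNINT8WS8);
  assert(SelectCustomizedTypeValue(DType::kInt8, DType::kFloat32) ==
         kConvMKLDNNINT8);
  return 0;
}

After this change no such value is computed; the filter type is resolved at run time inside the oneDNN kernel itself, as the following hunks show.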
@@ -20,6 +20,8 @@
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/phi/core/visit_type.h"
+
 namespace paddle {
 namespace operators {
 namespace {
@@ -774,7 +776,22 @@ class ConvMKLDNNHandlerT
 }  // anonymous namespace

-template <typename T, typename K>
+#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...)                    \
+  [&] {                                                                   \
+    const auto& __dtype__ = TYPE;                                         \
+    switch (__dtype__) {                                                  \
+      PD_PRIVATE_CASE_TYPE(                                               \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)          \
+      PD_PRIVATE_CASE_TYPE(                                               \
+          NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__)            \
+      default:                                                            \
+        PD_THROW("function " #NAME " is not implemented for data type `", \
+                 __dtype__,                                               \
+                 "`");                                                    \
+    }                                                                     \
+  }()
+
+template <typename T>
 class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {

@@ -828,48 +845,52 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
     auto* output = ctx.Output<phi::DenseTensor>("Output");

-    ConvMKLDNNHandlerT<T, K, T_out> handler(
-        ctx,
-        dev_ctx,
-        mkldnn_engine,
-        ctx.GetPlace(),
-        input,
-        filter,
-        bias,
-        output,
-        ctx.InputName("Input") + ctx.InputName("Filter"));
-
-    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
-    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-        filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
-
-    std::shared_ptr<dnnl::memory> dst_memory_p;
-    if (fuse_residual_conn) {
-      auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-      dst_memory_p =
-          handler.AcquireDstMemoryWithResidual(output, residual_param);
-    } else {
-      dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
-    }
-
-    auto conv_p = handler.AcquireForwardPrimitive();
-
-    std::unordered_map<int, dnnl::memory> args = {
-        {DNNL_ARG_SRC, *src_memory_p},
-        {DNNL_ARG_WEIGHTS, *weights_memory_p},
-        {DNNL_ARG_DST, *dst_memory_p}};
-
-    if (bias) {
-      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
-      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-    }
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    conv_p->execute(astream, args);
-    astream.wait();
-
-    output->set_mem_desc(dst_memory_p->get_desc());
+    PD_VISIT_FLOAT_AND_INT8_TYPES(
+        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
+          ConvMKLDNNHandlerT<T, data_t, T_out> handler(
+              ctx,
+              dev_ctx,
+              mkldnn_engine,
+              ctx.GetPlace(),
+              input,
+              filter,
+              bias,
+              output,
+              ctx.InputName("Input") + ctx.InputName("Filter"));
+
+          auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
+
+          auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
+              filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
+
+          std::shared_ptr<dnnl::memory> dst_memory_p;
+          if (fuse_residual_conn) {
+            auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
+            dst_memory_p =
+                handler.AcquireDstMemoryWithResidual(output, residual_param);
+          } else {
+            dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
+          }
+
+          auto conv_p = handler.AcquireForwardPrimitive();
+
+          std::unordered_map<int, dnnl::memory> args = {
+              {DNNL_ARG_SRC, *src_memory_p},
+              {DNNL_ARG_WEIGHTS, *weights_memory_p},
+              {DNNL_ARG_DST, *dst_memory_p}};
+
+          if (bias) {
+            auto bias_memory_p =
+                handler.AcquireBiasMemoryWithReorder(bias, is_test);
+            args.insert({DNNL_ARG_BIAS, *bias_memory_p});
+          }
+
+          auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+          conv_p->execute(astream, args);
+          astream.wait();
+
+          output->set_mem_desc(dst_memory_p->get_desc());
+        }));
   }

 template <typename T_out>
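The PD_VISIT_FLOAT_AND_INT8_TYPES macro introduced above is what replaces the old K template parameter: it switches on the filter's runtime dtype and textually expands the callback once per supported type, exposing the selected C++ type to the callback body as data_t. Below is a minimal, self-contained sketch of the same dispatch pattern, assuming a toy DataType enum and visitor macro rather than the real definitions in paddle/phi/core/visit_type.h:

#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class DataType { FLOAT32, INT8 };

// Each case exposes the visited C++ type to the callback as `data_t`,
// mirroring how PD_PRIVATE_CASE_TYPE is used by the macros in this diff.
#define VISIT_CASE(enum_value, cpp_type, ...) \
  case DataType::enum_value: {                \
    using data_t = cpp_type;                  \
    __VA_ARGS__();                            \
    break;                                    \
  }

#define VISIT_FLOAT_AND_INT8_TYPES(dtype, ...)             \
  [&] {                                                    \
    switch (dtype) {                                       \
      VISIT_CASE(FLOAT32, float, __VA_ARGS__)              \
      VISIT_CASE(INT8, int8_t, __VA_ARGS__)                \
      default:                                             \
        throw std::runtime_error("unsupported data type"); \
    }                                                      \
  }()

// Stand-in for ConvMKLDNNHandlerT<T, data_t, T_out>: the second parameter is
// the filter type chosen by the visitor at run time.
template <typename InT, typename FilterT>
void RunConvHandler() {
  std::printf("handler with input elem size %zu and filter elem size %zu\n",
              sizeof(InT), sizeof(FilterT));
}

template <typename T>
void Compute(DataType filter_dtype) {
  VISIT_FLOAT_AND_INT8_TYPES(filter_dtype, ([&] {
                               RunConvHandler<T, data_t>();
                             }));
}

int main() {
  Compute<float>(DataType::FLOAT32);  // expands the callback with data_t = float
  Compute<float>(DataType::INT8);     // expands the callback with data_t = int8_t
}

Because the filter type is now chosen at run time inside Compute, the kernel class keeps only the input type T as a template parameter, which in turn allows the plain REGISTER_OP_KERNEL registrations near the end of this diff.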
@@ -905,90 +926,113 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
     auto* output = ctx.Output<phi::DenseTensor>("Output");

-    ConvMKLDNNHandlerT<T, K, T_out> handler(
-        ctx,
-        dev_ctx,
-        mkldnn_engine,
-        ctx.GetPlace(),
-        input,
-        filter,
-        bias,
-        output,
-        ctx.InputName("Input") + ctx.InputName("Filter"));
-
-    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
-
-    const auto& scale_weights_data =
-        ctx.Attr<std::vector<float>>("Scale_weights");
-    const bool is_multi_channel = scale_weights_data.size() > 1;
-    const int& groups = ctx.Attr<int>("groups");
-    int mask_reorder =
-        is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
-    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-        filter, groups, false, true, scale_weights_data, mask_reorder);
-
-    std::shared_ptr<dnnl::memory> dst_memory_p;
-    if (fuse_residual_conn) {
-      auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-      PADDLE_ENFORCE_EQ(
-          output->dims(),
-          residual_param->dims(),
-          platform::errors::InvalidArgument(
-              "Output and elementwise parameter need to have the "
-              "same dimension sizes, but got output's dimension = %d"
-              " and residual param's dimension =%d .",
-              output->dims().size(),
-              residual_param->dims().size()));
-      dst_memory_p =
-          handler.AcquireDstMemoryWithResidual(output, residual_param);
-      need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
-                       dnnl::memory::data_type::s8) &&
-                      unsigned_output;
-    } else {
-      dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
-    }
-
-    auto conv_p = handler.AcquireForwardPrimitive();
-
-    std::unordered_map<int, dnnl::memory> args = {
-        {DNNL_ARG_SRC, *src_memory_p},
-        {DNNL_ARG_WEIGHTS, *weights_memory_p},
-        {DNNL_ARG_DST, *dst_memory_p}};
-
-    if (bias) {
-      std::vector<float> bias_scales;
-      auto p_scales_tuple =
-          std::make_shared<std::tuple<float, std::vector<float>>>(
-              std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
-      if (ctx.HasAttr("Bias_scales")) {
-        bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
-        p_scales_tuple =
-            std::make_shared<std::tuple<float, std::vector<float>>>(
-                std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
-      } else {
-        p_scales_tuple = handler.get_int8_bias_scales(ctx);
-      }
-      auto bias_memory_p =
-          handler.AcquireBiasMemoryWithReorder(bias,
-                                               true,
-                                               std::get<1>(*p_scales_tuple),
-                                               std::get<0>(*p_scales_tuple));
-      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-    }
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    conv_p->execute(astream, args);
-    astream.wait();
-
-    if (need_s8_to_u8) {
-      output->mutable_data<uint8_t>(ctx.GetPlace());
-    }
-
-    output->set_mem_desc(dst_memory_p->get_desc());
+    PD_VISIT_FLOAT_AND_INT8_TYPES(
+        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
+          ConvMKLDNNHandlerT<T, data_t, T_out> handler(
+              ctx,
+              dev_ctx,
+              mkldnn_engine,
+              ctx.GetPlace(),
+              input,
+              filter,
+              bias,
+              output,
+              ctx.InputName("Input") + ctx.InputName("Filter"));
+
+          auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
+
+          const auto& scale_weights_data =
+              ctx.Attr<std::vector<float>>("Scale_weights");
+          const bool is_multi_channel = scale_weights_data.size() > 1;
+          const int& groups = ctx.Attr<int>("groups");
+          int mask_reorder =
+              is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
+                               : 0;
+          auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
+              filter, groups, false, true, scale_weights_data, mask_reorder);
+
+          std::shared_ptr<dnnl::memory> dst_memory_p;
+          if (fuse_residual_conn) {
+            auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
+            PADDLE_ENFORCE_EQ(
+                output->dims(),
+                residual_param->dims(),
+                platform::errors::InvalidArgument(
+                    "Output and elementwise parameter need to have the "
+                    "same dimension sizes, but got output's dimension = %d"
+                    " and residual param's dimension =%d .",
+                    output->dims().size(),
+                    residual_param->dims().size()));
+            dst_memory_p =
+                handler.AcquireDstMemoryWithResidual(output, residual_param);
+            need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
+                             dnnl::memory::data_type::s8) &&
+                            unsigned_output;
+          } else {
+            dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
+          }
+
+          auto conv_p = handler.AcquireForwardPrimitive();
+
+          std::unordered_map<int, dnnl::memory> args = {
+              {DNNL_ARG_SRC, *src_memory_p},
+              {DNNL_ARG_WEIGHTS, *weights_memory_p},
+              {DNNL_ARG_DST, *dst_memory_p}};
+
+          if (bias) {
+            std::vector<float> bias_scales;
+            auto p_scales_tuple =
+                std::make_shared<std::tuple<float, std::vector<float>>>(
+                    std::make_tuple(static_cast<float>(mask_reorder),
+                                    bias_scales));
+            if (ctx.HasAttr("Bias_scales")) {
+              bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
+              p_scales_tuple =
+                  std::make_shared<std::tuple<float, std::vector<float>>>(
+                      std::make_tuple(static_cast<float>(mask_reorder),
+                                      bias_scales));
+            } else {
+              p_scales_tuple = handler.get_int8_bias_scales(ctx);
+            }
+            auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
+                bias,
+                true,
+                std::get<1>(*p_scales_tuple),
+                std::get<0>(*p_scales_tuple));
+            args.insert({DNNL_ARG_BIAS, *bias_memory_p});
+          }
+
+          auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+          conv_p->execute(astream, args);
+          astream.wait();
+
+          if (need_s8_to_u8) {
+            output->mutable_data<uint8_t>(ctx.GetPlace());
+          }
+
+          output->set_mem_desc(dst_memory_p->get_desc());
+        }));
   }
 };
-template <typename T, typename K>
+#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...)                    \
+  [&] {                                                                   \
+    const auto& __dtype__ = TYPE;                                         \
+    switch (__dtype__) {                                                  \
+      PD_PRIVATE_CASE_TYPE(                                               \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)          \
+      PD_PRIVATE_CASE_TYPE(NAME,                                          \
+                           ::paddle::DataType::BFLOAT16,                  \
+                           ::phi::dtype::bfloat16,                        \
+                           __VA_ARGS__)                                   \
+      default:                                                            \
+        PD_THROW("function " #NAME " is not implemented for data type `", \
+                 __dtype__,                                               \
+                 "`");                                                    \
+    }                                                                     \
+  }()
+
+template <typename T>
 class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {

@@ -1013,119 +1057,123 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
     if (!input_grad && !filter_grad) return;

-    // TODO(jczaja): Are all tensors really needed?
-    ConvMKLDNNHandlerT<T, K, T> handler(
-        ctx,
-        dev_ctx,
-        ctx.GetPlace(),
-        input,
-        filter,
-        bias,
-        output_grad,
-        filter_grad,
-        input_grad,
-        ctx.InputName("Input") + ctx.InputName("Filter"));
-
-    // create mkldnn memory from input tensors (data/weights)
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-
-    if (filter_grad) {
-      auto src_memory_p =
-          handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input);
-      auto diff_dst_memory_p =
-          handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
-              output_grad);
-
-      // For convoluition with groups write filter grad into
-      // oneDNN buffer and then we reorder it into filter_grad tensor
-      int g = std::max(ctx.Attr<int>("groups"), 1);
-
-      auto diff_weights_memory_p =
-          g > 1 ? handler.AcquireDiffWeightsMemory()
-                : handler.AcquireDiffWeightsMemory(filter_grad);
-
-      auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();
-
-      conv_bwd_weights_p->execute(
-          astream,
-          {{DNNL_ARG_SRC, *src_memory_p},
-           {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
-           {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
-      astream.wait();
-
-      // For convolution with groups convert from blocked to NCHW
-      // otherwise there will be problems in next operators working on this data
-      if (g > 1) {
-        // in OneDNN groups in convolution are treated as separate dimension
-        // which is not the case in paddlepaddle
-
-        dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
-            framework::TransToProtoVarType(filter->dtype()));
-        // for 3d conv with groups (six dimensional data reorder to goidhw)
-        // for 2d conv with groups (five dimensional data reorder to goihw)
-        // auto weights_tz = phi::vectorize(filter->dims());
-
-        auto weights_tz = diff_weights_memory_p->get_desc().dims();
-        dnnl::memory::format_tag out_format =
-            weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw
-                                   : dnnl::memory::format_tag::goihw;
-        platform::ReorderMKLDNNHandler handler(
-            weights_tz,
-            framework::TransToProtoVarType(filter->dtype()),
-            in_type,
-            mkldnn_engine);
-        auto reorder_dst_memory_p =
-            handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());
-
-        auto reorder_p =
-            handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p);
-
-        {
-          platform::RecordEvent record_reorder(
-              "int_reorder",
-              platform::TracerEventType::UserDefined,
-              2,
-              platform::EventRole::kUniqueOp);
-          reorder_p->execute(
-              astream, *diff_weights_memory_p, *reorder_dst_memory_p);
-          astream.wait();
-        }
-
-        // So here we have a data in goihw , which can be interpreted as OIHW
-        // (OIDHW for conv3d)
-        // because filter_grad shape is set for OIHW (OIDHW for conv3d)
-        dnnl::memory::format_tag target_format =
-            weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
-                                   : dnnl::memory::format_tag::oihw;
-        filter_grad->set_mem_desc(
-            dnnl::memory::desc(phi::vectorize<int64_t>(filter_grad->dims()),
-                               in_type,
-                               target_format));
-      } else {
-        filter_grad->set_mem_desc(diff_weights_memory_p->get_desc());
-      }
-    }
-    if (input_grad) {
-      auto weights_memory_p =
-          handler.AcquireWeightsMemoryWithReorderFromDataPrimitive(
-              filter,
-              ctx.Attr<int>("groups"),
-              ctx.Attr<std::vector<int>>("strides").size() == 3U);
-
-      auto diff_dst_memory_p =
-          handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
-              output_grad);
-      auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad);
-
-      auto conv_bwd_data_p = handler.AcquireBackwardPrimitive();
-
-      conv_bwd_data_p->execute(astream,
-                               {{DNNL_ARG_WEIGHTS, *weights_memory_p},
-                                {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
-                                {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
-      astream.wait();
-
-      input_grad->set_mem_desc(diff_src_memory_p->get_desc());
-    }
+    PD_VISIT_FLOAT_AND_BF16_TYPES(
+        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
+          // TODO(jczaja): Are all tensors really needed?
+          ConvMKLDNNHandlerT<T, data_t, T> handler(
+              ctx,
+              dev_ctx,
+              ctx.GetPlace(),
+              input,
+              filter,
+              bias,
+              output_grad,
+              filter_grad,
+              input_grad,
+              ctx.InputName("Input") + ctx.InputName("Filter"));
+
+          // create mkldnn memory from input tensors (data/weights)
+          auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+
+          if (filter_grad) {
+            auto src_memory_p =
+                handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input);
+            auto diff_dst_memory_p =
+                handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
+                    output_grad);
+
+            // For convoluition with groups write filter grad into
+            // oneDNN buffer and then we reorder it into filter_grad tensor
+            int g = std::max(ctx.Attr<int>("groups"), 1);
+
+            auto diff_weights_memory_p =
+                g > 1 ? handler.AcquireDiffWeightsMemory()
+                      : handler.AcquireDiffWeightsMemory(filter_grad);
+
+            auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();
+
+            conv_bwd_weights_p->execute(
+                astream,
+                {{DNNL_ARG_SRC, *src_memory_p},
+                 {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
+                 {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
+            astream.wait();
+
+            // For convolution with groups convert from blocked to NCHW
+            // otherwise there will be problems in next operators working on
+            // this data
+            if (g > 1) {
+              // in OneDNN groups in convolution are treated as separate
+              // dimension which is not the case in paddlepaddle
+
+              dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
+                  framework::TransToProtoVarType(filter->dtype()));
+              // for 3d conv with groups (six dimensional data reorder to
+              // goidhw) for 2d conv with groups (five dimensional data reorder
+              // to goihw) auto weights_tz = phi::vectorize(filter->dims());
+
+              auto weights_tz = diff_weights_memory_p->get_desc().dims();
+              dnnl::memory::format_tag out_format =
+                  weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw
+                                         : dnnl::memory::format_tag::goihw;
+              platform::ReorderMKLDNNHandler handler(
+                  weights_tz,
+                  framework::TransToProtoVarType(filter->dtype()),
+                  in_type,
+                  mkldnn_engine);
+              auto reorder_dst_memory_p = handler.AcquireDstMemory(
+                  filter_grad, out_format, ctx.GetPlace());
+
+              auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p,
+                                                      diff_weights_memory_p);
+
+              {
+                platform::RecordEvent record_reorder(
+                    "int_reorder",
+                    platform::TracerEventType::UserDefined,
+                    2,
+                    platform::EventRole::kUniqueOp);
+                reorder_p->execute(
+                    astream, *diff_weights_memory_p, *reorder_dst_memory_p);
+                astream.wait();
+              }
+
+              // So here we have a data in goihw , which can be interpreted as
+              // OIHW (OIDHW for conv3d) because filter_grad shape is set for
+              // OIHW (OIDHW for conv3d)
+              dnnl::memory::format_tag target_format =
+                  weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
+                                         : dnnl::memory::format_tag::oihw;
+              filter_grad->set_mem_desc(dnnl::memory::desc(
+                  phi::vectorize<int64_t>(filter_grad->dims()),
+                  in_type,
+                  target_format));
+            } else {
+              filter_grad->set_mem_desc(diff_weights_memory_p->get_desc());
+            }
+          }
+          if (input_grad) {
+            auto weights_memory_p =
+                handler.AcquireWeightsMemoryWithReorderFromDataPrimitive(
+                    filter,
+                    ctx.Attr<int>("groups"),
+                    ctx.Attr<std::vector<int>>("strides").size() == 3U);
+
+            auto diff_dst_memory_p =
+                handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
+                    output_grad);
+            auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad);
+
+            auto conv_bwd_data_p = handler.AcquireBackwardPrimitive();
+
+            conv_bwd_data_p->execute(astream,
+                                     {{DNNL_ARG_WEIGHTS, *weights_memory_p},
+                                      {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
+                                      {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
+            astream.wait();
+
+            input_grad->set_mem_desc(diff_src_memory_p->get_desc());
+          }
+        }));
   }
 };
@@ -1134,119 +1182,40 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;

-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    conv2d,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    U8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<uint8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    U8WS8,
-                                    ops::kConvMKLDNNINT8WS8,
-                                    ops::ConvMKLDNNOpKernel<uint8_t, int8_t>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    S8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<int8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    S8WS8,
-                                    ops::kConvMKLDNNINT8WS8,
-                                    ops::ConvMKLDNNOpKernel<int8_t, int8_t>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNGradOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    conv2d_grad,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16,
-                                paddle::platform::bfloat16>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    depthwise_conv2d,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    U8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<uint8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    S8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<int8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d_grad,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNGradOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    depthwise_conv2d_grad,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNGradOpKernel<float, float>);
+REGISTER_OP_KERNEL(conv2d,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNOpKernel<float>,
+                   ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
+                   ops::ConvMKLDNNOpKernel<uint8_t>,
+                   ops::ConvMKLDNNOpKernel<int8_t>);
+
+REGISTER_OP_KERNEL(conv2d_grad,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNGradOpKernel<float>,
+                   ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
+
+REGISTER_OP_KERNEL(depthwise_conv2d,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNOpKernel<float>,
+                   ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
+                   ops::ConvMKLDNNOpKernel<uint8_t>,
+                   ops::ConvMKLDNNOpKernel<int8_t>);
+
+REGISTER_OP_KERNEL(depthwise_conv2d_grad,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNGradOpKernel<float>,
+                   ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
+
+REGISTER_OP_KERNEL(conv3d,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL(conv3d_grad,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNGradOpKernel<float>);
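One way to read the registration change above: once the filter dtype is dispatched inside the kernel, the kernel key needs nothing beyond the input dtype (plus place, layout, and library), so the FP32/S8/U8/WS8 custom-type variants collapse into a single registration per input type. A toy registry sketch under that assumption (the registry structure and all names here are illustrative, not Paddle's actual kernel registry):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Toy kernel registry keyed only by (op name, input dtype), mirroring the
// simplified REGISTER_OP_KERNEL form. Before this commit the key also
// carried a customized type value describing the filter dtype.
enum class DataType { FLOAT32, BFLOAT16, INT8, UINT8 };

using Kernel = std::function<void(DataType /*filter_dtype*/)>;
static std::map<std::pair<std::string, DataType>, Kernel> g_kernels;

template <typename T>
void ConvKernel(DataType filter_dtype) {
  // Filter-specific work is dispatched here at run time (the role of
  // PD_VISIT_FLOAT_AND_INT8_TYPES), so no S8/U8/WS8 registry variants are needed.
  std::printf("conv kernel: input elem size %zu, filter dtype %d\n",
              sizeof(T), static_cast<int>(filter_dtype));
}

int main() {
  g_kernels[{"conv2d", DataType::FLOAT32}] = ConvKernel<float>;
  g_kernels[{"conv2d", DataType::INT8}] = ConvKernel<int8_t>;

  // Lookup uses only the input dtype; the filter dtype is an argument, not a key.
  g_kernels[{"conv2d", DataType::INT8}](DataType::INT8);
  return 0;
}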
@@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
 USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
 USE_OP_ITSELF(conv2d);
-USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
+USE_OP_DEVICE_KERNEL(conv2d, MKLDNN);

 namespace paddle {
 namespace operators {
......