Unverified commit b2727020, authored by jakpiase, committed by GitHub

added conv and conv_transpose support for md (#44677)

Parent: 6506668e
@@ -24,13 +24,13 @@ namespace paddle {
 namespace operators {
 namespace {
 
-inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
-                                           const int groups,
+inline MKLDNNMemoryFormat GetWeightsFormat(const int groups,
                                            const bool is_conv3d) {
   if (is_conv3d) {
-    return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
+    return (groups == 1) ? MKLDNNMemoryFormat::oidhw
+                         : MKLDNNMemoryFormat::goidhw;
   } else {
-    return (groups == 1) ? format : MKLDNNMemoryFormat::goihw;
+    return (groups == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw;
   }
 }
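Note on the hunk above: the `format` parameter is dropped because, once tensors carry a full `dnnl::memory::desc`, the helper no longer needs the caller's format tag and can name the canonical plain/grouped weights tags directly. A minimal standalone sketch of the new helper, assuming Paddle's usual alias `MKLDNNMemoryFormat = dnnl::memory::format_tag`:

```cpp
#include "dnnl.hpp"

// Assumption: mirrors Paddle's alias in platform/mkldnn_helper.h.
using MKLDNNMemoryFormat = dnnl::memory::format_tag;

// Picks the canonical oneDNN weights tag: plain (oihw/oidhw) for ungrouped
// convolutions, grouped (goihw/goidhw) otherwise.
inline MKLDNNMemoryFormat GetWeightsFormat(const int groups,
                                           const bool is_conv3d) {
  if (is_conv3d) {
    return (groups == 1) ? MKLDNNMemoryFormat::oidhw
                         : MKLDNNMemoryFormat::goidhw;
  }
  return (groups == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw;
}
```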
@@ -98,10 +98,6 @@ class ConvMKLDNNHandlerT
             "The input tensor's layout should be %d, but got %d.",
             framework::DataLayout::kMKLDNN,
             input->layout()));
-      PADDLE_ENFORCE_NE(input->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for Input tensor"));
 
       PADDLE_ENFORCE_EQ(
           filter->layout(),
@@ -110,10 +106,6 @@ class ConvMKLDNNHandlerT
             "The Filter tensor's layout should be %d, but got %d.",
             framework::DataLayout::kMKLDNN,
             filter->layout()));
-      PADDLE_ENFORCE_NE(filter->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for Filter tensor"));
 
       PADDLE_ENFORCE_GE(
           input->dims().size(),
@@ -153,10 +145,6 @@ class ConvMKLDNNHandlerT
               "The Bias tensor's layout should be %d, but got %d.",
               framework::DataLayout::kMKLDNN,
               bias->layout()));
-        PADDLE_ENFORCE_NE(bias->format(),
-                          MKLDNNMemoryFormat::undef,
-                          platform::errors::InvalidArgument(
-                              "Got wrong format for Bias tensor."));
 
         PADDLE_ENFORCE_EQ(bias->dims().size(),
                           1,
@@ -307,10 +295,6 @@ class ConvMKLDNNHandlerT
             "The input tensor's layout should be %d, but got %d.",
             framework::DataLayout::kMKLDNN,
             in->layout()));
-      PADDLE_ENFORCE_NE(in->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Got wrong format for Input tensor."));
 
       PADDLE_ENFORCE_EQ(
           filter->layout(),
@@ -319,10 +303,6 @@ class ConvMKLDNNHandlerT
             "The filter tensor's layout should be %d, but got %d.",
             framework::DataLayout::kMKLDNN,
             filter->layout()));
-      PADDLE_ENFORCE_NE(filter->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Got wrong format for Filter tensor."));
 
       PADDLE_ENFORCE_EQ(
           out_grad->layout(),
@@ -331,10 +311,6 @@ class ConvMKLDNNHandlerT
             "The output_grad tensor's layout should be %d, but got %d.",
             framework::DataLayout::kMKLDNN,
             out_grad->layout()));
-      PADDLE_ENFORCE_NE(out_grad->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for output_grad tensor"));
 
       PADDLE_ENFORCE_EQ(
           ctx.Attr<bool>("is_test"),
@@ -596,10 +572,10 @@ class ConvMKLDNNHandlerT
     auto weights_tz = phi::vectorize(filter->dims());
     platform::GetGroupConvWeightsTz(weights_tz, groups);
 
-    auto user_src_md = platform::MKLDNNMemDesc(
-        weights_tz,
-        platform::MKLDNNGetDataType<K>(),
-        GetWeightsFormat(filter->format(), groups, is_conv3d));
+    auto user_src_md =
+        platform::MKLDNNMemDesc(weights_tz,
+                                platform::MKLDNNGetDataType<K>(),
+                                GetWeightsFormat(groups, is_conv3d));
 
     return this->AcquireMemoryWithReorder(
         user_src_md,
@@ -660,12 +636,11 @@ class ConvMKLDNNHandlerT
     auto user_mem_p = this->AcquireMemory(user_key_suffix);
 
     if (!user_mem_p) {
-      auto user_mem_md =
-          platform::MKLDNNMemDesc(phi::vectorize(in_mem->dims()),
-                                  platform::MKLDNNGetDataType<T>(),
-                                  in_mem->format());
       return this->AcquireMemoryWithReorder(
-          user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem);
+          in_mem->mem_desc(),
+          mem_md,
+          platform::to_void_cast<T>(in_mem_data),
+          key_mem);
     } else {
       const std::string target_key_suffix{key_mem_target};
       const auto target_mem_p = this->AcquireMemory(target_key_suffix);
@@ -694,10 +669,10 @@ class ConvMKLDNNHandlerT
     auto weights_tz = phi::vectorize(filter->dims());
     platform::GetGroupConvWeightsTz(weights_tz, groups);
 
-    auto user_src_md = platform::MKLDNNMemDesc(
-        weights_tz,
-        platform::MKLDNNGetDataType<K>(),
-        GetWeightsFormat(filter->format(), groups, is_conv3d));
+    auto user_src_md =
+        platform::MKLDNNMemDesc(weights_tz,
+                                platform::MKLDNNGetDataType<K>(),
+                                GetWeightsFormat(groups, is_conv3d));
 
     return this->AcquireMemoryWithReorder(
         user_src_md,
@@ -713,10 +688,10 @@ class ConvMKLDNNHandlerT
     auto weights_tz = phi::vectorize(filter->dims());
     platform::GetGroupConvWeightsTz(weights_tz, groups);
 
-    auto user_src_md = platform::MKLDNNMemDesc(
-        weights_tz,
-        platform::MKLDNNGetDataType<T>(),
-        GetWeightsFormat(filter->format(), groups, is_conv3d));
+    auto user_src_md =
+        platform::MKLDNNMemDesc(weights_tz,
+                                platform::MKLDNNGetDataType<T>(),
+                                GetWeightsFormat(groups, is_conv3d));
 
     return this->AcquireMemoryWithReorder(
         user_src_md,
@@ -747,13 +722,9 @@ class ConvMKLDNNHandlerT
       LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype();
     }
     const K_Bias* bias_data = bias->data<K_Bias>();
 
-    auto user_bias_md =
-        platform::MKLDNNMemDesc(phi::vectorize(bias->dims()),
-                                platform::MKLDNNGetDataType<K_Bias>(),
-                                MKLDNNMemoryFormat::x);
     return this->AcquireMemoryWithReorder(
-        user_bias_md,
+        bias->mem_desc(),
         this->fwd_pd_->bias_desc(),
         platform::to_void_cast<K_Bias>(bias_data),
         "@bias_mem_p",
@@ -776,22 +747,16 @@ class ConvMKLDNNHandlerT
       residual_mem_p->set_data_handle(residual_data);
       return residual_mem_p;
     } else {
-      auto user_residual_md = platform::MKLDNNMemDesc(
-          phi::vectorize(residual_param->dims()),
-          framework::ToMKLDNNDataType(
-              framework::TransToProtoVarType(residual_param->dtype())),
-          residual_param->format());
-
-      return this->AcquireMemoryFromPrimitive(
-          user_residual_md, residual_data, "@user_residual_data_mem_p");
+      return this->AcquireMemoryFromPrimitive(residual_param->mem_desc(),
+                                              residual_data,
+                                              "@user_residual_data_mem_p");
     }
   }
 
   std::shared_ptr<dnnl::memory> AcquireDstMemoryWithResidual(
       framework::Tensor* output, const framework::Tensor* residual_param) {
     std::shared_ptr<dnnl::memory> dst_memory_p;
-    if (residual_param->format() !=
-        platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
+    if (residual_param->mem_desc() != this->fwd_pd_->dst_desc()) {
       auto residual_memory_p = this->AcquireResidualMemory(residual_param);
       dst_memory_p = this->template AcquireDstMemory<T_out>(output);
       this->AcquireReorder(residual_memory_p, dst_memory_p);
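Note the changed condition above: `residual_param->mem_desc() != this->fwd_pd_->dst_desc()` compares whole descriptors (dims, data type, blocking, padding) rather than the format tags the old `format()`-based check reduced them to, so layouts that the tag comparison could not distinguish now trigger a reorder correctly. A one-line sketch of the check, with a hypothetical helper name and descriptor arguments standing in for the two descriptors:

```cpp
#include "dnnl.hpp"

// Hypothetical helper: residual_md stands for residual_param->mem_desc(),
// dst_md for fwd_pd_->dst_desc(); desc equality covers dims, dtype, blocking.
bool NeedsResidualReorder(const dnnl::memory::desc& residual_md,
                          const dnnl::memory::desc& dst_md) {
  return residual_md != dst_md;
}
```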
@@ -903,8 +868,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
     conv_p->execute(astream, args);
     astream.wait();
 
-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+    output->set_mem_desc(dst_memory_p->get_desc());
   }
 
   template <typename T_out>
@@ -1018,8 +982,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
       output->mutable_data<uint8_t>(ctx.GetPlace());
     }
 
-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+    output->set_mem_desc(dst_memory_p->get_desc());
   }
 };
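Both kernel epilogues now end in a single `set_mem_desc` call: the `dnnl::memory::desc` taken from the primitive's destination memory already encodes the exact layout, so the separate `set_layout(kMKLDNN)` + `set_format(...)` bookkeeping disappears (presumably `set_mem_desc` also marks the tensor as kMKLDNN internally, which is why no explicit layout call remains). A small self-contained sketch, illustrative only, of why the stored descriptor loses nothing, using a blocked layout that a bare format tag describes poorly:

```cpp
#include <cassert>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  // A blocked layout such as nChw8c is where a tag-only representation gets
  // lossy once padding enters the picture; the full desc keeps everything.
  dnnl::memory::desc dst_md({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nChw8c);
  dnnl::memory dst_mem(dst_md, eng);
  // What output->set_mem_desc(dst_memory_p->get_desc()) stores: the very
  // same descriptor, blocking included.
  assert(dst_mem.get_desc() == dst_md);
  return 0;
}
```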
@@ -1078,7 +1041,6 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
     auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();
 
-    // TODO(grygielski) why no bias_diff?
     conv_bwd_weights_p->execute(
         astream,
         {{DNNL_ARG_SRC, *src_memory_p},
......
@@ -59,11 +59,6 @@ class ConvTransposeMKLDNNHandlerT
             DataLayout::kMKLDNN,
             platform::errors::InvalidArgument(
                 "Got wrong layout = %d for Input tensor.", input->layout()));
-    PADDLE_ENFORCE_NE(input->format(),
-                      MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Got wrong format for Input tensor. The input "
-                          "format is undefined."));
 
     PADDLE_ENFORCE_EQ(
         filter->layout(),
@@ -72,10 +67,6 @@ class ConvTransposeMKLDNNHandlerT
             "The filter tensor's layout should be %d, but got %d.",
             DataLayout::kMKLDNN,
             filter->layout()));
-    PADDLE_ENFORCE_NE(filter->format(),
-                      MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Got wrong formats for Filter tensor."));
 
     PADDLE_ENFORCE_EQ(
         input->dims().size(),
@@ -98,10 +89,6 @@ class ConvTransposeMKLDNNHandlerT
             "The bias tensor's laytout should be %d, but got %d.",
             DataLayout::kMKLDNN,
             bias->layout()));
-      PADDLE_ENFORCE_NE(bias->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Got wrong format for Bias tensor."));
 
       PADDLE_ENFORCE_EQ(
           bias->dims().size(),
@@ -233,11 +220,8 @@ class ConvTransposeMKLDNNHandlerT
   std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
-    auto user_src_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()),
-                                               platform::MKLDNNGetDataType<T>(),
-                                               input->format());
     return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
-        AcquireMemoryWithReorder(user_src_md,
+        AcquireMemoryWithReorder(input->mem_desc(),
                                  this->fwd_pd_->src_desc(),
                                  platform::to_void_cast<T>(input_data));
   }
@@ -427,8 +411,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     conv_p->execute(astream, args);
     astream.wait();
 
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+    output->set_mem_desc(dst_memory_p->get_desc());
   }
 };
......