Unverified commit b2727020, authored by jakpiase, committed by GitHub

added conv and conv_transpose support for md (#44677)

Parent 6506668e
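
This commit drops the legacy MKLDNNMemoryFormat plumbing from the oneDNN conv and conv_transpose kernels: instead of rebuilding a oneDNN memory descriptor from a tensor's dims, data type, and a cached format tag, the kernels reuse the dnnl::memory::desc already stored on the Tensor (mem_desc), and record output layout via set_mem_desc. A minimal sketch of the recurring before/after pattern, reusing names from the diff below (a fragment, not compilable on its own; the variable names legacy_md and md are illustrative):

// Before: the descriptor was reconstructed from dims + dtype + a cached
// format tag, which could be undef and had to be validated separately.
auto legacy_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()),
                                         platform::MKLDNNGetDataType<T>(),
                                         input->format());

// After: the tensor already carries a full dnnl::memory::desc, so it is
// used directly and the undef-format checks become unnecessary.
auto md = input->mem_desc();

// Outputs follow the same pattern: store the primitive's destination
// descriptor on the tensor instead of a layout + format-tag pair.
output->set_mem_desc(dst_memory_p->get_desc());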
@@ -24,13 +24,13 @@ namespace paddle {
 namespace operators {
 namespace {
-inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
-                                           const int groups,
+inline MKLDNNMemoryFormat GetWeightsFormat(const int groups,
                                            const bool is_conv3d) {
   if (is_conv3d) {
-    return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
+    return (groups == 1) ? MKLDNNMemoryFormat::oidhw
+                         : MKLDNNMemoryFormat::goidhw;
   } else {
-    return (groups == 1) ? format : MKLDNNMemoryFormat::goihw;
+    return (groups == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw;
   }
 }
@@ -98,10 +98,6 @@ class ConvMKLDNNHandlerT
                           "The input tensor's layout should be %d, but got %d.",
                           framework::DataLayout::kMKLDNN,
                           input->layout()));
-      PADDLE_ENFORCE_NE(input->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for Input tensor"));
       PADDLE_ENFORCE_EQ(
           filter->layout(),
@@ -110,10 +106,6 @@ class ConvMKLDNNHandlerT
                           "The Filter tensor's layout should be %d, but got %d.",
                           framework::DataLayout::kMKLDNN,
                           filter->layout()));
-      PADDLE_ENFORCE_NE(filter->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for Filter tensor"));
       PADDLE_ENFORCE_GE(
           input->dims().size(),
@@ -153,10 +145,6 @@ class ConvMKLDNNHandlerT
                             "The Bias tensor's layout should be %d, but got %d.",
                             framework::DataLayout::kMKLDNN,
                             bias->layout()));
-        PADDLE_ENFORCE_NE(bias->format(),
-                          MKLDNNMemoryFormat::undef,
-                          platform::errors::InvalidArgument(
-                              "Got wrong format for Bias tensor."));
         PADDLE_ENFORCE_EQ(bias->dims().size(),
                           1,
@@ -307,10 +295,6 @@ class ConvMKLDNNHandlerT
                           "The input tensor's layout should be %d, but got %d.",
                           framework::DataLayout::kMKLDNN,
                           in->layout()));
-      PADDLE_ENFORCE_NE(in->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Got wrong format for Input tensor."));
       PADDLE_ENFORCE_EQ(
           filter->layout(),
@@ -319,10 +303,6 @@ class ConvMKLDNNHandlerT
                           "The filter tensor's layout should be %d, but got %d.",
                           framework::DataLayout::kMKLDNN,
                           filter->layout()));
-      PADDLE_ENFORCE_NE(filter->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Got wrong format for Filter tensor."));
       PADDLE_ENFORCE_EQ(
           out_grad->layout(),
@@ -331,10 +311,6 @@ class ConvMKLDNNHandlerT
                           "The output_grad tensor's layout should be %d, but got %d.",
                           framework::DataLayout::kMKLDNN,
                           out_grad->layout()));
-      PADDLE_ENFORCE_NE(out_grad->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for output_grad tensor"));
       PADDLE_ENFORCE_EQ(
           ctx.Attr<bool>("is_test"),
@@ -596,10 +572,10 @@ class ConvMKLDNNHandlerT
     auto weights_tz = phi::vectorize(filter->dims());
     platform::GetGroupConvWeightsTz(weights_tz, groups);

-    auto user_src_md = platform::MKLDNNMemDesc(
-        weights_tz,
-        platform::MKLDNNGetDataType<K>(),
-        GetWeightsFormat(filter->format(), groups, is_conv3d));
+    auto user_src_md =
+        platform::MKLDNNMemDesc(weights_tz,
+                                platform::MKLDNNGetDataType<K>(),
+                                GetWeightsFormat(groups, is_conv3d));

     return this->AcquireMemoryWithReorder(
         user_src_md,
@@ -660,12 +636,11 @@ class ConvMKLDNNHandlerT
       auto user_mem_p = this->AcquireMemory(user_key_suffix);

       if (!user_mem_p) {
-        auto user_mem_md =
-            platform::MKLDNNMemDesc(phi::vectorize(in_mem->dims()),
-                                    platform::MKLDNNGetDataType<T>(),
-                                    in_mem->format());
         return this->AcquireMemoryWithReorder(
-            user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem);
+            in_mem->mem_desc(),
+            mem_md,
+            platform::to_void_cast<T>(in_mem_data),
+            key_mem);
       } else {
         const std::string target_key_suffix{key_mem_target};
         const auto target_mem_p = this->AcquireMemory(target_key_suffix);
@@ -694,10 +669,10 @@ class ConvMKLDNNHandlerT
     auto weights_tz = phi::vectorize(filter->dims());
     platform::GetGroupConvWeightsTz(weights_tz, groups);

-    auto user_src_md = platform::MKLDNNMemDesc(
-        weights_tz,
-        platform::MKLDNNGetDataType<K>(),
-        GetWeightsFormat(filter->format(), groups, is_conv3d));
+    auto user_src_md =
+        platform::MKLDNNMemDesc(weights_tz,
+                                platform::MKLDNNGetDataType<K>(),
+                                GetWeightsFormat(groups, is_conv3d));

     return this->AcquireMemoryWithReorder(
         user_src_md,
@@ -713,10 +688,10 @@ class ConvMKLDNNHandlerT
     auto weights_tz = phi::vectorize(filter->dims());
     platform::GetGroupConvWeightsTz(weights_tz, groups);

-    auto user_src_md = platform::MKLDNNMemDesc(
-        weights_tz,
-        platform::MKLDNNGetDataType<T>(),
-        GetWeightsFormat(filter->format(), groups, is_conv3d));
+    auto user_src_md =
+        platform::MKLDNNMemDesc(weights_tz,
+                                platform::MKLDNNGetDataType<T>(),
+                                GetWeightsFormat(groups, is_conv3d));

     return this->AcquireMemoryWithReorder(
         user_src_md,
@@ -747,13 +722,9 @@ class ConvMKLDNNHandlerT
       LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype();
     }
     const K_Bias* bias_data = bias->data<K_Bias>();

-    auto user_bias_md =
-        platform::MKLDNNMemDesc(phi::vectorize(bias->dims()),
-                                platform::MKLDNNGetDataType<K_Bias>(),
-                                MKLDNNMemoryFormat::x);
     return this->AcquireMemoryWithReorder(
-        user_bias_md,
+        bias->mem_desc(),
         this->fwd_pd_->bias_desc(),
         platform::to_void_cast<K_Bias>(bias_data),
         "@bias_mem_p",
@@ -776,22 +747,16 @@ class ConvMKLDNNHandlerT
       residual_mem_p->set_data_handle(residual_data);
       return residual_mem_p;
     } else {
-      auto user_residual_md = platform::MKLDNNMemDesc(
-          phi::vectorize(residual_param->dims()),
-          framework::ToMKLDNNDataType(
-              framework::TransToProtoVarType(residual_param->dtype())),
-          residual_param->format());
-      return this->AcquireMemoryFromPrimitive(
-          user_residual_md, residual_data, "@user_residual_data_mem_p");
+      return this->AcquireMemoryFromPrimitive(residual_param->mem_desc(),
+                                              residual_data,
+                                              "@user_residual_data_mem_p");
     }
   }

   std::shared_ptr<dnnl::memory> AcquireDstMemoryWithResidual(
       framework::Tensor* output, const framework::Tensor* residual_param) {
     std::shared_ptr<dnnl::memory> dst_memory_p;
-    if (residual_param->format() !=
-        platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
+    if (residual_param->mem_desc() != this->fwd_pd_->dst_desc()) {
       auto residual_memory_p = this->AcquireResidualMemory(residual_param);
       dst_memory_p = this->template AcquireDstMemory<T_out>(output);
       this->AcquireReorder(residual_memory_p, dst_memory_p);
@@ -903,8 +868,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
     conv_p->execute(astream, args);
     astream.wait();

-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+    output->set_mem_desc(dst_memory_p->get_desc());
   }

   template <typename T_out>
@@ -1018,8 +982,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
       output->mutable_data<uint8_t>(ctx.GetPlace());
     }

-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+    output->set_mem_desc(dst_memory_p->get_desc());
   }
 };
@@ -1078,7 +1041,6 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
       auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();

       // TODO(grygielski) why no bias_diff?
       conv_bwd_weights_p->execute(
           astream,
           {{DNNL_ARG_SRC, *src_memory_p},
......
@@ -59,11 +59,6 @@ class ConvTransposeMKLDNNHandlerT
             DataLayout::kMKLDNN,
             platform::errors::InvalidArgument(
                 "Got wrong layout = %d for Input tensor.", input->layout()));
-    PADDLE_ENFORCE_NE(input->format(),
-                      MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Got wrong format for Input tensor. The input "
-                          "format is undefined."));

     PADDLE_ENFORCE_EQ(
         filter->layout(),
@@ -72,10 +67,6 @@ class ConvTransposeMKLDNNHandlerT
                           "The filter tensor's layout should be %d, but got %d.",
                           DataLayout::kMKLDNN,
                           filter->layout()));
-    PADDLE_ENFORCE_NE(filter->format(),
-                      MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Got wrong formats for Filter tensor."));
     PADDLE_ENFORCE_EQ(
         input->dims().size(),
@@ -98,10 +89,6 @@ class ConvTransposeMKLDNNHandlerT
                             "The bias tensor's laytout should be %d, but got %d.",
                             DataLayout::kMKLDNN,
                             bias->layout()));
-      PADDLE_ENFORCE_NE(bias->format(),
-                        MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Got wrong format for Bias tensor."));
       PADDLE_ENFORCE_EQ(
           bias->dims().size(),
@@ -233,11 +220,8 @@ class ConvTransposeMKLDNNHandlerT
   std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
-    auto user_src_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()),
-                                               platform::MKLDNNGetDataType<T>(),
-                                               input->format());
     return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
-        AcquireMemoryWithReorder(user_src_md,
+        AcquireMemoryWithReorder(input->mem_desc(),
                                  this->fwd_pd_->src_desc(),
                                  platform::to_void_cast<T>(input_data));
   }
@@ -427,8 +411,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     conv_p->execute(astream, args);
     astream.wait();
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+    output->set_mem_desc(dst_memory_p->get_desc());
   }
 };
......