未验证 提交 673bf719 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Disable caching for interpolate and batch_norm (#35030)

* - disabled caching in the interpolate oneDNN kernel

* - compilation fix

* - draft of batch_norm cache disabling

* - fixes to unit tests (UT)
上级 a047c139
...@@ -35,159 +35,141 @@ using paddle::platform::MKLDNNDeviceContext; ...@@ -35,159 +35,141 @@ using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast; using platform::to_void_cast;
template <typename T> template <typename T>
class BatchNormMKLDNNHandler class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward> { mkldnn::batch_normalization_backward> {
public: public:
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx, const mkldnn::engine mkldnn_engine, const Tensor *x,
const mkldnn::engine mkldnn_engine, const bool global_stats, const bool test_mode)
platform::Place cpu_place, const Tensor *x, : platform::MKLDNNHandlerNoCachingT<T,
const bool global_stats, const bool test_mode, mkldnn::batch_normalization_forward,
const std::string &unique_name) mkldnn::batch_normalization_backward>(
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, mkldnn_engine, ctx.GetPlace()) {
mkldnn::batch_normalization_backward>( const float epsilon = ctx.Attr<float>("epsilon");
dev_ctx, dev_ctx.GetEngine(), cpu_place, const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
unique_name)) { std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
if (!this->isCached()) { "kAnyLayout", "kMKLDNN"};
const float epsilon = ctx.Attr<float>("epsilon"); PADDLE_ENFORCE_EQ(
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu"); x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW", "Wrong layout set for X tensor. Expected layout is `kMKLDNN`, "
"kAnyLayout", "kMKLDNN"}; "But received %s.",
PADDLE_ENFORCE_EQ( DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)]));
x->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_NE(
platform::errors::InvalidArgument( x->format(), MKLDNNMemoryFormat::undef,
"Wrong layout set for X tensor. Expected layout is `kMKLDNN`, " platform::errors::InvalidArgument("Wrong format set for X tensor"));
"But received %s.",
DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)])); auto src_tz = paddle::framework::vectorize(x->dims());
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef, // Flags are added by bitwise OR operation
platform::errors::InvalidArgument("Wrong format set for X tensor")); auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
if (global_stats)
auto src_tz = paddle::framework::vectorize(x->dims()); flags |= mkldnn::normalization_flags::use_global_stats; // 010
if (fuse_with_relu && test_mode)
// Flags are added by bitwise OR operation flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
if (global_stats) auto md = mkldnn::memory::desc(
flags |= mkldnn::normalization_flags::use_global_stats; // 010 src_tz, platform::MKLDNNGetDataType<T>(),
if (fuse_with_relu && test_mode) platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
this->AcquireForwardPrimitiveDescriptor(
auto md = mkldnn::memory::desc( global_stats == true ? mkldnn::prop_kind::forward_scoring
src_tz, platform::MKLDNNGetDataType<T>(), : mkldnn::prop_kind::forward_training,
platform::MKLDNNFormatForSize(src_tz.size(), x->format())); md, epsilon, flags);
this->AcquireForwardPrimitiveDescriptor(
global_stats == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training,
md, epsilon, flags);
}
} }
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx, const mkldnn::engine mkldnn_engine, const Tensor *in_x,
platform::Place cpu_place, const Tensor *in_x, const Tensor *scale, const Tensor *out_grad)
const Tensor *scale, const Tensor *out_grad, : platform::MKLDNNHandlerNoCachingT<T,
const std::string &unique_name) mkldnn::batch_normalization_forward,
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, mkldnn::batch_normalization_backward>(
mkldnn::batch_normalization_backward>( mkldnn_engine, ctx.GetPlace()) {
dev_ctx, dev_ctx.GetEngine(), cpu_place, PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()), platform::errors::InvalidArgument(
unique_name)) { "Wrong layout set for Input out_grad tensor"));
if (!this->isBwdCached()) { PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Wrong format set for Input out_grad tensor"));
"Wrong layout set for Input out_grad tensor"));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef, auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
platform::errors::InvalidArgument( auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
"Wrong format set for Input out_grad tensor")); PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims()); platform::errors::InvalidArgument(
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims()); "Dims of scale tensor must be 1, but received scale's size is %d",
PADDLE_ENFORCE_EQ( scale_tz.size()));
scale_tz.size(), 1,
platform::errors::InvalidArgument( MKLDNNMemoryFormat diff_fmt =
"Dims of scale tensor must be 1, but received scale's size is %d", platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
scale_tz.size()));
MKLDNNMemoryFormat src_fmt =
MKLDNNMemoryFormat diff_fmt = platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
auto dims = framework::vectorize(in_x->dims());
MKLDNNMemoryFormat src_fmt = auto diff_dst_md =
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format()); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
auto dims = framework::vectorize(in_x->dims()); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
auto diff_dst_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt); const float epsilon = ctx.Attr<float>("epsilon");
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt); this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, src_md, epsilon,
const float epsilon = ctx.Attr<float>("epsilon"); mkldnn::normalization_flags::use_scale_shift);
this->AcquireBackwardPrimitiveDescriptor(
this->AcquireForwardPrimitiveDescriptor( mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
mkldnn::prop_kind::forward_training, src_md, epsilon, mkldnn::normalization_flags::use_scale_shift);
mkldnn::normalization_flags::use_scale_shift);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
mkldnn::normalization_flags::use_scale_shift);
}
} }
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale, std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
const Tensor *shift, const Tensor *shift) {
const bool is_test) { auto scale_tz = paddle::framework::vectorize(scale->dims());
auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p"); const unsigned int C = scale_tz[0];
if (scaleshift_memory == nullptr || !is_test) { PADDLE_ENFORCE_EQ(
auto scale_tz = paddle::framework::vectorize(scale->dims()); scale_tz.size(), 1,
const unsigned int C = scale_tz[0]; platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ( "Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size(), 1, scale_tz.size()));
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d", auto scaleshift_memory =
scale_tz.size())); this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
auto mem_p = this->AcquireMemoryFromPrimitive( // MKLDNN requires a single piece of memory for scale and shift/bias data
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p"); auto mem_data_handle =
reinterpret_cast<T *>(scaleshift_memory->get_data_handle());
// MKLDNN requires a single piece of memory for scale and shift/bias data std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle()); std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
return mem_p;
}
return scaleshift_memory; return scaleshift_memory;
} }
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory( std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
T *diff_scaleshift_data) { T *diff_scaleshift_data) {
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(), return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
diff_scaleshift_data, diff_scaleshift_data);
"@diff_scaleshift_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireMeanMemory( std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
const framework::Tensor *mean) { const framework::Tensor *mean) {
const T *mean_data = mean->data<T>(); const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p"); to_void_cast<T>(mean_data));
} }
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) { std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
T *mean_data = mean->mutable_data<T>(this->place_, T *mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size()); this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data, "@mean_mem_p"); mean_data);
} }
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
const framework::Tensor *variance) { const framework::Tensor *variance) {
const T *variance_data = variance->data<T>(); const T *variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
to_void_cast<T>(variance_data), to_void_cast<T>(variance_data));
"@variance_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
...@@ -195,7 +177,7 @@ class BatchNormMKLDNNHandler ...@@ -195,7 +177,7 @@ class BatchNormMKLDNNHandler
T *variance_data = variance->mutable_data<T>( T *variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size()); this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data, "@variance_mem_p"); variance_data);
} }
}; };
...@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto *batch_mean = ctx.Output<Tensor>("SavedMean"); auto *batch_mean = ctx.Output<Tensor>("SavedMean");
auto *batch_variance = ctx.Output<Tensor>("SavedVariance"); auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, global_stats,
ctx.GetPlace(), x, global_stats, test_mode);
test_mode, ctx.OutputName("SavedMean"));
auto src_memory = handler.AcquireSrcMemory(x); auto src_memory = handler.AcquireSrcMemory(x);
auto scaleshift_memory = auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
handler.AcquireScaleShiftMemory(scale, shift, is_test);
auto dst_memory = handler.AcquireDstMemory(y); auto dst_memory = handler.AcquireDstMemory(y);
auto batch_norm_p = handler.AcquireForwardPrimitive(); auto batch_norm_p = handler.AcquireForwardPrimitive();
...@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale, BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, scale, diff_y);
diff_y, ctx.InputName("SavedMean"));
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const unsigned int C = paddle::framework::vectorize(scale->dims())[0]; const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
...@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto mean_memory = handler.AcquireMeanMemory(batch_mean); auto mean_memory = handler.AcquireMeanMemory(batch_mean);
auto variance_memory = handler.AcquireVarianceMemory(batch_variance); auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y); auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
auto scaleshift_memory = auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
handler.AcquireScaleShiftMemory(scale, shift, false);
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x); auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
auto diff_scaleshift_memory = auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
......
...@@ -30,27 +30,21 @@ using platform::to_void_cast; ...@@ -30,27 +30,21 @@ using platform::to_void_cast;
template <typename T = float> template <typename T = float>
class InterpolateMKLDNNHandler class InterpolateMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
public: public:
InterpolateMKLDNNHandler(const dnnl::algorithm algo, InterpolateMKLDNNHandler(const dnnl::algorithm algo,
const platform::MKLDNNDeviceContext& dev_ctx,
const dnnl::engine engine, platform::Place cpu_place, const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, Tensor* z, const Tensor* x, Tensor* z)
const std::string& uniq_name) : platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
: platform::MKLDNNHandlerT<T, dnnl::resampling_forward>( engine, cpu_place) {
dev_ctx, engine, cpu_place, const auto src_x_tz = framework::vectorize(x->dims());
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), const auto dst_tz = framework::vectorize(z->dims());
uniq_name)) { const auto src_md = dnnl::memory::desc(
if (!this->isCached()) { src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto src_x_tz = framework::vectorize(x->dims()); const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
const auto dst_tz = framework::vectorize(z->dims()); MKLDNNMemoryFormat::any);
const auto src_md = dnnl::memory::desc( this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format()); algo, src_md, dst_md);
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, algo, src_md, dst_md);
}
} }
}; };
...@@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<Tensor>("X"); const auto* x = ctx.Input<Tensor>("X");
std::vector<float> scale_prior;
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
...@@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
auto out_dims_vec = ComputeOutputShape(ctx); auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = framework::make_ddim(out_dims_vec); framework::DDim dim_out = framework::make_ddim(out_dims_vec);
z->mutable_data<T>(dim_out, ctx.GetPlace()); z->Resize(dim_out);
InterpolateMKLDNNHandler<T> handler(algo, dev_ctx, mkldnn_engine, InterpolateMKLDNNHandler<T> handler(algo, mkldnn_engine, ctx.GetPlace(), x,
ctx.GetPlace(), x, z, z);
ctx.OutputName("Out"));
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(z); auto dst_memory_p = handler.AcquireDstMemory(z);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册