未验证 提交 673bf719 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] disable caching for interpolate and batch Norm (#35030)

* - disabled interpolate onednn

* - compilation fix

* - draft of batch_norm cache disabling

* - fixes to UT
上级 a047c139
...@@ -35,22 +35,17 @@ using paddle::platform::MKLDNNDeviceContext; ...@@ -35,22 +35,17 @@ using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast; using platform::to_void_cast;
template <typename T> template <typename T>
class BatchNormMKLDNNHandler class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward> { mkldnn::batch_normalization_backward> {
public: public:
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx, const mkldnn::engine mkldnn_engine, const Tensor *x,
const mkldnn::engine mkldnn_engine, const bool global_stats, const bool test_mode)
platform::Place cpu_place, const Tensor *x, : platform::MKLDNNHandlerNoCachingT<T,
const bool global_stats, const bool test_mode, mkldnn::batch_normalization_forward,
const std::string &unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>( mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, mkldnn_engine, ctx.GetPlace()) {
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
unique_name)) {
if (!this->isCached()) {
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu"); const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
...@@ -84,19 +79,14 @@ class BatchNormMKLDNNHandler ...@@ -84,19 +79,14 @@ class BatchNormMKLDNNHandler
: mkldnn::prop_kind::forward_training, : mkldnn::prop_kind::forward_training,
md, epsilon, flags); md, epsilon, flags);
} }
}
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx, const mkldnn::engine mkldnn_engine, const Tensor *in_x,
platform::Place cpu_place, const Tensor *in_x, const Tensor *scale, const Tensor *out_grad)
const Tensor *scale, const Tensor *out_grad, : platform::MKLDNNHandlerNoCachingT<T,
const std::string &unique_name) mkldnn::batch_normalization_forward,
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>( mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, mkldnn_engine, ctx.GetPlace()) {
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
unique_name)) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong layout set for Input out_grad tensor")); "Wrong layout set for Input out_grad tensor"));
...@@ -119,8 +109,8 @@ class BatchNormMKLDNNHandler ...@@ -119,8 +109,8 @@ class BatchNormMKLDNNHandler
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format()); platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
auto dims = framework::vectorize(in_x->dims()); auto dims = framework::vectorize(in_x->dims());
auto diff_dst_md = mkldnn::memory::desc( auto diff_dst_md =
dims, platform::MKLDNNGetDataType<T>(), diff_fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
...@@ -133,13 +123,9 @@ class BatchNormMKLDNNHandler ...@@ -133,13 +123,9 @@ class BatchNormMKLDNNHandler
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
mkldnn::normalization_flags::use_scale_shift); mkldnn::normalization_flags::use_scale_shift);
} }
}
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale, std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
const Tensor *shift, const Tensor *shift) {
const bool is_test) {
auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
if (scaleshift_memory == nullptr || !is_test) {
auto scale_tz = paddle::framework::vectorize(scale->dims()); auto scale_tz = paddle::framework::vectorize(scale->dims());
const unsigned int C = scale_tz[0]; const unsigned int C = scale_tz[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -148,46 +134,42 @@ class BatchNormMKLDNNHandler ...@@ -148,46 +134,42 @@ class BatchNormMKLDNNHandler
"Dims of scale tensor must be 1, but received scale's size is %d", "Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size())); scale_tz.size()));
auto mem_p = this->AcquireMemoryFromPrimitive( auto scaleshift_memory =
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p"); this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle()); auto mem_data_handle =
reinterpret_cast<T *>(scaleshift_memory->get_data_handle());
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle); std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C); std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
return mem_p;
}
return scaleshift_memory; return scaleshift_memory;
} }
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory( std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
T *diff_scaleshift_data) { T *diff_scaleshift_data) {
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(), return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
diff_scaleshift_data, diff_scaleshift_data);
"@diff_scaleshift_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireMeanMemory( std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
const framework::Tensor *mean) { const framework::Tensor *mean) {
const T *mean_data = mean->data<T>(); const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p"); to_void_cast<T>(mean_data));
} }
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) { std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
T *mean_data = mean->mutable_data<T>(this->place_, T *mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size()); this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data, "@mean_mem_p"); mean_data);
} }
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
const framework::Tensor *variance) { const framework::Tensor *variance) {
const T *variance_data = variance->data<T>(); const T *variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
to_void_cast<T>(variance_data), to_void_cast<T>(variance_data));
"@variance_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
...@@ -195,7 +177,7 @@ class BatchNormMKLDNNHandler ...@@ -195,7 +177,7 @@ class BatchNormMKLDNNHandler
T *variance_data = variance->mutable_data<T>( T *variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size()); this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data, "@variance_mem_p"); variance_data);
} }
}; };
...@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto *batch_mean = ctx.Output<Tensor>("SavedMean"); auto *batch_mean = ctx.Output<Tensor>("SavedMean");
auto *batch_variance = ctx.Output<Tensor>("SavedVariance"); auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, global_stats,
ctx.GetPlace(), x, global_stats, test_mode);
test_mode, ctx.OutputName("SavedMean"));
auto src_memory = handler.AcquireSrcMemory(x); auto src_memory = handler.AcquireSrcMemory(x);
auto scaleshift_memory = auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
handler.AcquireScaleShiftMemory(scale, shift, is_test);
auto dst_memory = handler.AcquireDstMemory(y); auto dst_memory = handler.AcquireDstMemory(y);
auto batch_norm_p = handler.AcquireForwardPrimitive(); auto batch_norm_p = handler.AcquireForwardPrimitive();
...@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale, BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, scale, diff_y);
diff_y, ctx.InputName("SavedMean"));
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const unsigned int C = paddle::framework::vectorize(scale->dims())[0]; const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
...@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto mean_memory = handler.AcquireMeanMemory(batch_mean); auto mean_memory = handler.AcquireMeanMemory(batch_mean);
auto variance_memory = handler.AcquireVarianceMemory(batch_variance); auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y); auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
auto scaleshift_memory = auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
handler.AcquireScaleShiftMemory(scale, shift, false);
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x); auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
auto diff_scaleshift_memory = auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
......
...@@ -30,27 +30,21 @@ using platform::to_void_cast; ...@@ -30,27 +30,21 @@ using platform::to_void_cast;
template <typename T = float> template <typename T = float>
class InterpolateMKLDNNHandler class InterpolateMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
public: public:
InterpolateMKLDNNHandler(const dnnl::algorithm algo, InterpolateMKLDNNHandler(const dnnl::algorithm algo,
const platform::MKLDNNDeviceContext& dev_ctx,
const dnnl::engine engine, platform::Place cpu_place, const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, Tensor* z, const Tensor* x, Tensor* z)
const std::string& uniq_name) : platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
: platform::MKLDNNHandlerT<T, dnnl::resampling_forward>( engine, cpu_place) {
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
const auto src_x_tz = framework::vectorize(x->dims()); const auto src_x_tz = framework::vectorize(x->dims());
const auto dst_tz = framework::vectorize(z->dims()); const auto dst_tz = framework::vectorize(z->dims());
const auto src_md = dnnl::memory::desc( const auto src_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format()); src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(), const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
dnnl::prop_kind::forward_inference, algo, src_md, dst_md); algo, src_md, dst_md);
}
} }
}; };
...@@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<Tensor>("X"); const auto* x = ctx.Input<Tensor>("X");
std::vector<float> scale_prior;
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
...@@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
auto out_dims_vec = ComputeOutputShape(ctx); auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = framework::make_ddim(out_dims_vec); framework::DDim dim_out = framework::make_ddim(out_dims_vec);
z->mutable_data<T>(dim_out, ctx.GetPlace()); z->Resize(dim_out);
InterpolateMKLDNNHandler<T> handler(algo, dev_ctx, mkldnn_engine, InterpolateMKLDNNHandler<T> handler(algo, mkldnn_engine, ctx.GetPlace(), x,
ctx.GetPlace(), x, z, z);
ctx.OutputName("Out"));
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(z); auto dst_memory_p = handler.AcquireDstMemory(z);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册