未验证 提交 673bf719 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] disable caching for interpolate and batch Norm (#35030)

* - disabled interpolate onednn

* - compilation fix

* - draft of batch_norm cache disabling

* - fixes to UT
上级 a047c139
......@@ -35,22 +35,17 @@ using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;
template <typename T>
class BatchNormMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward> {
public:
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor *x,
const bool global_stats, const bool test_mode,
const std::string &unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
const mkldnn::engine mkldnn_engine, const Tensor *x,
const bool global_stats, const bool test_mode)
: platform::MKLDNNHandlerNoCachingT<T,
mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
unique_name)) {
if (!this->isCached()) {
mkldnn_engine, ctx.GetPlace()) {
const float epsilon = ctx.Attr<float>("epsilon");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
......@@ -84,19 +79,14 @@ class BatchNormMKLDNNHandler
: mkldnn::prop_kind::forward_training,
md, epsilon, flags);
}
}
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx,
platform::Place cpu_place, const Tensor *in_x,
const Tensor *scale, const Tensor *out_grad,
const std::string &unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
const mkldnn::engine mkldnn_engine, const Tensor *in_x,
const Tensor *scale, const Tensor *out_grad)
: platform::MKLDNNHandlerNoCachingT<T,
mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
unique_name)) {
if (!this->isBwdCached()) {
mkldnn_engine, ctx.GetPlace()) {
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input out_grad tensor"));
......@@ -119,8 +109,8 @@ class BatchNormMKLDNNHandler
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
auto dims = framework::vectorize(in_x->dims());
auto diff_dst_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto diff_dst_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
......@@ -133,13 +123,9 @@ class BatchNormMKLDNNHandler
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
mkldnn::normalization_flags::use_scale_shift);
}
}
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
const Tensor *shift,
const bool is_test) {
auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
if (scaleshift_memory == nullptr || !is_test) {
const Tensor *shift) {
auto scale_tz = paddle::framework::vectorize(scale->dims());
const unsigned int C = scale_tz[0];
PADDLE_ENFORCE_EQ(
......@@ -148,46 +134,42 @@ class BatchNormMKLDNNHandler
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
auto mem_p = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
auto scaleshift_memory =
this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
// MKLDNN requires a single piece of memory for scale and shift/bias data
auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle());
auto mem_data_handle =
reinterpret_cast<T *>(scaleshift_memory->get_data_handle());
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
return mem_p;
}
return scaleshift_memory;
}
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
T *diff_scaleshift_data) {
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
diff_scaleshift_data,
"@diff_scaleshift_mem_p");
diff_scaleshift_data);
}
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
const framework::Tensor *mean) {
const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p");
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
to_void_cast<T>(mean_data));
}
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
T *mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data, "@mean_mem_p");
mean_data);
}
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
const framework::Tensor *variance) {
const T *variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
to_void_cast<T>(variance_data),
"@variance_mem_p");
to_void_cast<T>(variance_data));
}
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
......@@ -195,7 +177,7 @@ class BatchNormMKLDNNHandler
T *variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data, "@variance_mem_p");
variance_data);
}
};
......@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto *batch_mean = ctx.Output<Tensor>("SavedMean");
auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
ctx.GetPlace(), x, global_stats,
test_mode, ctx.OutputName("SavedMean"));
BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, global_stats,
test_mode);
auto src_memory = handler.AcquireSrcMemory(x);
auto scaleshift_memory =
handler.AcquireScaleShiftMemory(scale, shift, is_test);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
auto dst_memory = handler.AcquireDstMemory(y);
auto batch_norm_p = handler.AcquireForwardPrimitive();
......@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale,
diff_y, ctx.InputName("SavedMean"));
BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, scale, diff_y);
// MKLDNN requires a single piece of memory for scale and shift/bias data
const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
......@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto mean_memory = handler.AcquireMeanMemory(batch_mean);
auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
auto scaleshift_memory =
handler.AcquireScaleShiftMemory(scale, shift, false);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
......
......@@ -30,27 +30,21 @@ using platform::to_void_cast;
template <typename T = float>
class InterpolateMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> {
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
public:
InterpolateMKLDNNHandler(const dnnl::algorithm algo,
const platform::MKLDNNDeviceContext& dev_ctx,
const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, Tensor* z,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::resampling_forward>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
const Tensor* x, Tensor* z)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
engine, cpu_place) {
const auto src_x_tz = framework::vectorize(x->dims());
const auto dst_tz = framework::vectorize(z->dims());
const auto src_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, algo, src_md, dst_md);
}
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
algo, src_md, dst_md);
}
};
......@@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<Tensor>("X");
std::vector<float> scale_prior;
auto* z = ctx.Output<Tensor>("Out");
auto interp_method = ctx.Attr<std::string>("interp_method");
......@@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = framework::make_ddim(out_dims_vec);
z->mutable_data<T>(dim_out, ctx.GetPlace());
z->Resize(dim_out);
InterpolateMKLDNNHandler<T> handler(algo, dev_ctx, mkldnn_engine,
ctx.GetPlace(), x, z,
ctx.OutputName("Out"));
InterpolateMKLDNNHandler<T> handler(algo, mkldnn_engine, ctx.GetPlace(), x,
z);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(z);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册