Unverified · Commit b490e41c authored by Adam, committed by GitHub

Add isCached() mechanism for BatchNorm and LRN oneDNN operators (#24798)

* Add isCached() mechanism for BatchNorm and LRN oneDNN operators
  test=develop

* Formatting fix
  test=develop

Parent 9d66385f
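
The handler key is built from the input dims plus a unique op name, and isCached() lets the rewritten constructors skip attribute parsing and primitive-descriptor creation when a primitive for that key was already built on an earlier iteration. Below is a minimal self-contained sketch of that cache-check pattern, assuming a plain string-keyed map in place of Paddle's device-context blob store; the names (blob_map, Handler, "@fwd_pd") are illustrative stand-ins, not Paddle or oneDNN APIs, and the real isCached() body is not part of this diff.

    // Self-contained sketch of the caching pattern this commit introduces:
    // setup work runs only when the key has not been seen before.
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    std::map<std::string, std::shared_ptr<int>> blob_map;  // stands in for the device-context blob cache

    struct Handler {
      explicit Handler(const std::string& key) : key_(key) {
        if (!isCached()) {                  // mirrors if (!this->isCached())
          pd_ = std::make_shared<int>(42);  // stands in for building the primitive descriptor
          blob_map[key_ + "@fwd_pd"] = pd_;
          std::cout << "built descriptor for " << key_ << "\n";
        }
      }
      bool isCached() {
        auto it = blob_map.find(key_ + "@fwd_pd");
        if (it == blob_map.end()) return false;
        pd_ = it->second;  // reuse the previously built descriptor
        return true;
      }
      std::string key_;
      std::shared_ptr<int> pd_;
    };

    int main() {
      Handler a("2x3x224x224@SavedMean");  // first iteration: builds and caches
      Handler b("2x3x224x224@SavedMean");  // later iterations: cache hit, setup skipped
    }

The handlers changed below get the equivalent check from the MKLDNNHandlerT base class.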
@@ -31,22 +31,45 @@ class BatchNormMKLDNNHandler
     : public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
                                       mkldnn::batch_normalization_backward> {
  public:
-  BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
-                         const mkldnn::normalization_flags &flags,
-                         const bool &global_stats, const MKLDNNMemoryFormat fmt,
+  BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                          const platform::MKLDNNDeviceContext &dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string &uniq_name)
+                         const mkldnn::engine mkldnn_engine,
+                         platform::Place cpu_place, const Tensor *x,
+                         const bool global_stats, const bool test_mode,
+                         const std::string &unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
                                  mkldnn::batch_normalization_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
-    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireForwardPrimitiveDescriptor(
-        global_stats == true ? mkldnn::prop_kind::forward_scoring
-                             : mkldnn::prop_kind::forward_training,
-        md, epsilon, flags);
+            platform::CreateKey(framework::vectorize(x->dims()), unique_name)) {
+    if (!this->isCached()) {
+      const float epsilon = ctx.Attr<float>("epsilon");
+      const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
+
+      PADDLE_ENFORCE_EQ(
+          x->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for X tensor"));
+      PADDLE_ENFORCE_NE(
+          x->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for X tensor"));
+
+      auto src_tz = paddle::framework::vectorize(x->dims());
+
+      // Flags are added by bitwise OR operation
+      auto flags = mkldnn::normalization_flags::use_scale_shift;  // 001
+      if (global_stats)
+        flags |= mkldnn::normalization_flags::use_global_stats;  // 010
+      if (fuse_with_relu && test_mode)
+        flags |= mkldnn::normalization_flags::fuse_norm_relu;  // 100
+
+      auto md = mkldnn::memory::desc(
+          src_tz, platform::MKLDNNGetDataType<T>(),
+          platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
+
+      this->AcquireForwardPrimitiveDescriptor(
+          global_stats == true ? mkldnn::prop_kind::forward_scoring
+                               : mkldnn::prop_kind::forward_training,
+          md, epsilon, flags);
+    }
   }

   BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
                          const mkldnn::normalization_flags &flags,
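
The constructor above composes the oneDNN normalization flags with bitwise OR, as the 001/010/100 comments indicate: use_scale_shift is always set, use_global_stats is added when running statistics are used, and fuse_norm_relu only when the fused ReLU is requested in test mode. A small stand-alone illustration of that composition, using plain enum values rather than mkldnn::normalization_flags:

    // Self-contained illustration of composing bit flags the way the handler does;
    // the enum values mirror the 001/010/100 comments but are not the mkldnn types.
    #include <cstdio>

    enum NormFlags : unsigned {
      kUseScaleShift  = 1u << 0,  // 001
      kUseGlobalStats = 1u << 1,  // 010
      kFuseNormRelu   = 1u << 2,  // 100
    };

    int main() {
      bool global_stats = true, fuse_with_relu = true, test_mode = false;

      unsigned flags = kUseScaleShift;                          // always set
      if (global_stats) flags |= kUseGlobalStats;               // inference with running stats
      if (fuse_with_relu && test_mode) flags |= kFuseNormRelu;  // fused ReLU only in test mode

      std::printf("flags = 0x%x\n", flags);  // 0x3 here: scale/shift + global stats
    }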
@@ -68,9 +91,30 @@ class BatchNormMKLDNNHandler
         mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
   }

-  std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) {
-    return this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(), scaleshift_data, "@scaleshift_mem_p");
+  std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
+                                                          const Tensor *shift,
+                                                          const bool is_test) {
+    auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
+    if (scaleshift_memory == nullptr || !is_test) {
+      auto scale_tz = paddle::framework::vectorize(scale->dims());
+      const unsigned int C = scale_tz[0];
+      PADDLE_ENFORCE_EQ(
+          scale_tz.size(), 1,
+          platform::errors::InvalidArgument(
+              "Dims of scale tensor must be 1, but received scale's size is %d",
+              scale_tz.size()));
+
+      auto mem_p = this->AcquireMemoryFromPrimitive(
+          this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
+
+      // MKLDNN requires a single piece of memory for scale and shift/bias data
+      auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle());
+      std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
+      std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
+      return mem_p;
+    }
+    return scaleshift_memory;
   }

   std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
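
AcquireScaleShiftMemory now copies the Scale and Bias tensors directly into the cached oneDNN weights memory as one contiguous block of 2*C values: scale occupies the first C entries and shift the next C, and the buffer is refilled whenever the memory is missing or the op is not in test mode. A self-contained sketch of that packing, with plain vectors standing in for Tensor and mkldnn::memory:

    // Self-contained sketch of packing scale and shift into one 2*C buffer,
    // mirroring the std::copy calls in AcquireScaleShiftMemory (no Paddle types).
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      const unsigned C = 3;
      std::vector<float> scale = {1.0f, 1.1f, 1.2f};
      std::vector<float> shift = {0.0f, 0.1f, 0.2f};

      std::vector<float> scaleshift(2 * C);  // stands in for the weights memory handle
      std::copy(scale.begin(), scale.end(), scaleshift.begin());      // [0, C)   <- scale
      std::copy(shift.begin(), shift.end(), scaleshift.begin() + C);  // [C, 2*C) <- shift

      for (float v : scaleshift) std::printf("%.1f ", v);  // 1.0 1.1 1.2 0.0 0.1 0.2
      std::printf("\n");
    }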
@@ -115,64 +159,30 @@ template <typename T>
 class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const float momentum = ctx.Attr<float>("momentum");
+    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto &mkldnn_engine = dev_ctx.GetEngine();
+
     const bool is_test = ctx.Attr<bool>("is_test");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
-    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
     const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
-    bool test_mode = is_test && (!trainable_stats);
-
-    bool global_stats = test_mode || use_global_stats;
-
-    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const bool test_mode = is_test && (!trainable_stats);
+    const bool global_stats = test_mode || use_global_stats;

     const auto *x = ctx.Input<Tensor>("X");
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *shift = ctx.Input<Tensor>("Bias");

     auto *y = ctx.Output<Tensor>("Y");
-    auto *mean_out = ctx.Output<Tensor>("MeanOut");
-    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
     auto *batch_mean = ctx.Output<Tensor>("SavedMean");
     auto *batch_variance = ctx.Output<Tensor>("SavedVariance");

-    PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
-                      "Wrong layout set for X tensor");
-    PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
-                      "Wrong format set for X tensor");
-
-    auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
-    auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
-    PADDLE_ENFORCE_EQ(
-        scale_tz.size(), 1,
-        platform::errors::InvalidArgument(
-            "Dims of scale tensor must be 1, but received scale's size is %d",
-            scale_tz.size()));
-
-    const unsigned int C = scale_tz[0];
-
-    // MKLDNN requires a single piece of memory for scale and shift/bias data
-    std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
-    scaleshift_data.reserve(2 * C);
-    scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
-                           shift->data<T>() + C);
-
-    // Flags are added by bitwise OR operation
-    auto flags = mkldnn::normalization_flags::use_scale_shift;  // 001
-    if (global_stats)
-      flags |= mkldnn::normalization_flags::use_global_stats;  // 010
-    if (fuse_with_relu && test_mode)
-      flags |= mkldnn::normalization_flags::fuse_norm_relu;  // 100
-
-    BatchNormMKLDNNHandler<T> handler(
-        src_tz, epsilon, flags, global_stats,
-        platform::MKLDNNFormatForSize(src_tz.size(), x->format()), dev_ctx,
-        ctx.GetPlace(), ctx.OutputName("SavedMean"));
+    BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
+                                      ctx.GetPlace(), x, global_stats,
+                                      test_mode, ctx.OutputName("SavedMean"));

     auto src_memory = handler.AcquireSrcMemory(x);
     auto scaleshift_memory =
-        handler.AcquireScaleShiftMemory(scaleshift_data.data());
+        handler.AcquireScaleShiftMemory(scale, shift, is_test);
     auto dst_memory = handler.AcquireDstMemory(y);
     auto batch_norm_p = handler.AcquireForwardPrimitive();
@@ -206,6 +216,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     astream.wait();

     if (!global_stats) {
+      auto *mean_out = ctx.Output<Tensor>("MeanOut");
+      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
+      const float momentum = ctx.Attr<float>("momentum");
+
+      const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
+
       // mkldnn only compute stats for current batch
       // so we need compute momentum stats via Eigen lib
       EigenVectorArrayMap<T> batch_mean_e(
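
Because oneDNN reports only the statistics of the current batch, MeanOut and VarianceOut are blended on the Paddle side with the momentum attribute, following the usual running-statistics convention running = running * momentum + batch * (1 - momentum). A tiny stand-alone version of that update, with plain arrays in place of the Eigen-mapped tensors:

    // Self-contained sketch of the momentum update applied to the running statistics;
    // plain arrays stand in for the Eigen-mapped MeanOut/SavedMean tensors.
    #include <cstdio>

    int main() {
      const float momentum = 0.9f;
      float running_mean[2] = {0.50f, 0.25f};  // previous running statistic
      float batch_mean[2]   = {0.60f, 0.20f};  // statistic oneDNN computed for this batch

      for (int c = 0; c < 2; ++c)
        running_mean[c] = running_mean[c] * momentum + batch_mean[c] * (1.0f - momentum);

      std::printf("%.3f %.3f\n", running_mean[0], running_mean[1]);  // 0.510 0.245
    }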
@@ -273,11 +289,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     // MKLDNN requires a single piece of memory for scale and shift/bias data
     const size_t scaleshift_size = 2 * C;

-    std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
-    scaleshift_data.reserve(scaleshift_size);
-    scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
-                           shift->data<T>() + C);
-
     std::vector<T> diff_scaleshift_data;
     diff_scaleshift_data.reserve(scaleshift_size);

@@ -286,7 +297,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
     auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
     auto scaleshift_memory =
-        handler.AcquireScaleShiftMemory(scaleshift_data.data());
+        handler.AcquireScaleShiftMemory(scale, shift, false);
     auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
     auto diff_scaleshift_memory =
         handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
......
@@ -33,29 +33,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                       paddle::platform::errors::PreconditionNotMet(
                           "Operator DNNL LRN must use CPUPlace"));
-    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();

     auto x = ctx.Input<Tensor>("X");
     auto out = ctx.Output<Tensor>("Out");
     auto mid = ctx.Output<Tensor>("MidOut");

-    const int n = ctx.Attr<int>("n");
-    // MKL-DNN implements LRN in a caffe way:
-    // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
-    // Where sum of squares is divided by size of normalization window
-    // this is not the case for PaddlePaddle LRN.
-    // Hence we need to compensate for this diffrence by
-    // multipliing alpha by size of window(n)
-    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
-    const float beta = ctx.Attr<float>("beta");
-    const float k = ctx.Attr<float>("k");
-    bool is_test = ctx.Attr<bool>("is_test");
-
-    auto dims = paddle::framework::vectorize<int64_t>(x->dims());
-
-    platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
-                                          is_test, dev_ctx, ctx.GetPlace(),
-                                          ctx.OutputName("Out"));
+    platform::LRNMKLDNNHandler<T> handler(
+        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, ctx.OutputName("Out"));

     auto src_memory = handler.AcquireSrcMemory(x);
     auto dst_memory = handler.AcquireDstMemory(out);
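
The comment block removed here (and re-added inside LRNMKLDNNHandler further down) documents why alpha is multiplied by the window size n: oneDNN follows Caffe's LRN, which divides the sum of squares by n, while PaddlePaddle's LRN does not, so passing alpha * n makes the two formulations agree. A numeric sanity check of that identity, with made-up values:

    // Self-contained numeric check of the alpha * n compensation described above.
    #include <cmath>
    #include <cstdio>

    int main() {
      const int n = 5;             // normalization window size
      const float alpha = 1e-4f, beta = 0.75f, k = 1.0f;
      const float sum_sq = 20.0f;  // sum of squares inside the window

      // PaddlePaddle formulation: alpha multiplies the raw sum of squares.
      float paddle_denom = std::pow(k + alpha * sum_sq, beta);
      // Caffe/oneDNN formulation: the sum of squares is divided by n,
      // so feeding alpha * n reproduces the Paddle result.
      float onednn_denom = std::pow(k + (alpha * n) * (sum_sq / n), beta);

      std::printf("%.6f %.6f\n", paddle_denom, onednn_denom);  // identical
    }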
@@ -77,6 +64,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     // TODO(jczaja): Disable checking mid in unit tests (Require API change)
     mid->mutable_data<T>(ctx.GetPlace());
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
+    const float k = ctx.Attr<float>("k");
     e_mid = e_mid.constant(k);
     mid->set_format(platform::GetMKLDNNFormat(*dst_memory));
......
@@ -162,7 +162,7 @@ class MKLDNNHandlerT {
   std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
       mkldnn::memory::desc md, void* ptr, const std::string& suffix) {
-    auto local_key = key_ + suffix;
+    const auto local_key = key_ + suffix;
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
     if (mem_p == nullptr) {
@@ -174,6 +174,24 @@ class MKLDNNHandlerT {
     return mem_p;
   }

+  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
+      mkldnn::memory::desc md, const std::string& suffix) {
+    const auto local_key = key_ + suffix;
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    if (mem_p == nullptr) {
+      mem_p = std::make_shared<mkldnn::memory>(md, engine_);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    }
+    return mem_p;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
+    const auto local_key = key_ + suffix;
+    return std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx_.GetBlob(local_key));
+  }
+
   const MKLDNNDeviceContext& dev_ctx_;
   mkldnn::engine engine_;
   platform::Place place_;
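
Two helpers are added to MKLDNNHandlerT: the new AcquireMemoryFromPrimitive overload creates and caches a memory object that owns its own buffer when the key is missing, while AcquireMemory is a pure lookup that returns nullptr on a miss; AcquireScaleShiftMemory uses that distinction to decide whether the packed scale/shift buffer still has to be filled. A self-contained contrast of the two behaviors, with a std::map standing in for the device-context blob store and illustrative names throughout:

    // Self-contained contrast of lookup-only vs create-if-missing acquisition;
    // std::map stands in for the device-context blob store, floats for mkldnn::memory.
    #include <cassert>
    #include <map>
    #include <memory>
    #include <string>

    std::map<std::string, std::shared_ptr<float>> blobs;

    // Mirrors AcquireMemory(): returns nullptr when nothing was cached yet.
    std::shared_ptr<float> Acquire(const std::string& key) {
      auto it = blobs.find(key);
      return it == blobs.end() ? nullptr : it->second;
    }

    // Mirrors the new AcquireMemoryFromPrimitive(md, suffix): creates and caches on a miss.
    std::shared_ptr<float> AcquireOrCreate(const std::string& key) {
      auto mem = Acquire(key);
      if (mem == nullptr) {
        mem = std::make_shared<float>(0.0f);  // stands in for mkldnn::memory(md, engine_)
        blobs[key] = mem;
      }
      return mem;
    }

    int main() {
      assert(Acquire("@scaleshift_mem_p") == nullptr);  // first iteration: cache miss
      auto m1 = AcquireOrCreate("@scaleshift_mem_p");   // create and cache
      auto m2 = Acquire("@scaleshift_mem_p");           // later iterations: same object
      assert(m1 == m2);
    }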
@@ -535,21 +553,39 @@ template <typename T>
 class LRNMKLDNNHandler
     : public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
  public:
-  LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
-                   const float alpha, const float beta, const float k,
-                   const MKLDNNMemoryFormat fmt, bool is_test,
+  LRNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                    const platform::MKLDNNDeviceContext& dev_ctx,
-                   platform::Place cpu_place, const std::string& unique_name)
+                   const mkldnn::engine mkldnn_engine,
+                   platform::Place cpu_place, const Tensor* input,
+                   const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, unique_name)) {
-    auto src_md =
-        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-    this->AcquireForwardPrimitiveDescriptor(
-        is_test ? mkldnn::prop_kind::forward_inference
-                : mkldnn::prop_kind::forward_training,
-        mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
+            dev_ctx, mkldnn_engine, cpu_place,
+            platform::CreateKey(framework::vectorize(input->dims()),
+                                unique_name)) {
+    if (!this->isCached()) {
+      const int n = ctx.Attr<int>("n");
+      // MKL-DNN implements LRN in a caffe way:
+      // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
+      // Where sum of squares is divided by size of normalization window
+      // this is not the case for PaddlePaddle LRN.
+      // Hence we need to compensate for this diffrence by
+      // multipliing alpha by size of window(n)
+      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
+      const float beta = ctx.Attr<float>("beta");
+      const float k = ctx.Attr<float>("k");
+      bool is_test = ctx.Attr<bool>("is_test");
+
+      auto dims = paddle::framework::vectorize(input->dims());
+
+      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                         input->format());
+
+      this->AcquireForwardPrimitiveDescriptor(
+          is_test ? mkldnn::prop_kind::forward_inference
+                  : mkldnn::prop_kind::forward_training,
+          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
+    }
   }

   LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
......