未验证 提交 b490e41c 编写于 作者: A Adam 提交者: GitHub

Add isCached() mechanism for BatchNorm and LRN oneDNN operators (#24798)

* Add isCached() mechanism for BatchNorm and LRN oneDNN operators
test=develop

* Formatting fix
test=develop
上级 9d66385f
......@@ -31,23 +31,46 @@ class BatchNormMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward> {
public:
BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
const mkldnn::normalization_flags &flags,
const bool &global_stats, const MKLDNNMemoryFormat fmt,
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx,
platform::Place cpu_place,
const std::string &uniq_name)
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor *x,
const bool global_stats, const bool test_mode,
const std::string &unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
platform::CreateKey(framework::vectorize(x->dims()), unique_name)) {
if (!this->isCached()) {
const float epsilon = ctx.Attr<float>("epsilon");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor"));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor"));
auto src_tz = paddle::framework::vectorize(x->dims());
// Flags are added by bitwise OR operation
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
if (global_stats)
flags |= mkldnn::normalization_flags::use_global_stats; // 010
if (fuse_with_relu && test_mode)
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
auto md = mkldnn::memory::desc(
src_tz, platform::MKLDNNGetDataType<T>(),
platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
this->AcquireForwardPrimitiveDescriptor(
global_stats == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training,
md, epsilon, flags);
}
}
BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
const mkldnn::normalization_flags &flags,
const MKLDNNMemoryFormat diff_fmt,
......@@ -68,9 +91,30 @@ class BatchNormMKLDNNHandler
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
}
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) {
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), scaleshift_data, "@scaleshift_mem_p");
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
const Tensor *shift,
const bool is_test) {
auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
if (scaleshift_memory == nullptr || !is_test) {
auto scale_tz = paddle::framework::vectorize(scale->dims());
const unsigned int C = scale_tz[0];
PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
auto mem_p = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
// MKLDNN requires a single piece of memory for scale and shift/bias data
auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle());
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
return mem_p;
}
return scaleshift_memory;
}
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
......@@ -115,64 +159,30 @@ template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
bool global_stats = test_mode || use_global_stats;
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const bool test_mode = is_test && (!trainable_stats);
const bool global_stats = test_mode || use_global_stats;
const auto *x = ctx.Input<Tensor>("X");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *batch_mean = ctx.Output<Tensor>("SavedMean");
auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for X tensor");
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for X tensor");
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
const unsigned int C = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias data
std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
scaleshift_data.reserve(2 * C);
scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
shift->data<T>() + C);
// Flags are added by bitwise OR operation
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
if (global_stats)
flags |= mkldnn::normalization_flags::use_global_stats; // 010
if (fuse_with_relu && test_mode)
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
BatchNormMKLDNNHandler<T> handler(
src_tz, epsilon, flags, global_stats,
platform::MKLDNNFormatForSize(src_tz.size(), x->format()), dev_ctx,
ctx.GetPlace(), ctx.OutputName("SavedMean"));
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
ctx.GetPlace(), x, global_stats,
test_mode, ctx.OutputName("SavedMean"));
auto src_memory = handler.AcquireSrcMemory(x);
auto scaleshift_memory =
handler.AcquireScaleShiftMemory(scaleshift_data.data());
handler.AcquireScaleShiftMemory(scale, shift, is_test);
auto dst_memory = handler.AcquireDstMemory(y);
auto batch_norm_p = handler.AcquireForwardPrimitive();
......@@ -206,6 +216,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
astream.wait();
if (!global_stats) {
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
const float momentum = ctx.Attr<float>("momentum");
const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(
......@@ -273,11 +289,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * C;
std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
scaleshift_data.reserve(scaleshift_size);
scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
shift->data<T>() + C);
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
......@@ -286,7 +297,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
auto scaleshift_memory =
handler.AcquireScaleShiftMemory(scaleshift_data.data());
handler.AcquireScaleShiftMemory(scale, shift, false);
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
......
......@@ -33,29 +33,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL LRN must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto x = ctx.Input<Tensor>("X");
auto out = ctx.Output<Tensor>("Out");
auto mid = ctx.Output<Tensor>("MidOut");
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// Where sum of squares is divided by size of normalization window
// this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this diffrence by
// multipliing alpha by size of window(n)
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto dims = paddle::framework::vectorize<int64_t>(x->dims());
platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
is_test, dev_ctx, ctx.GetPlace(),
ctx.OutputName("Out"));
platform::LRNMKLDNNHandler<T> handler(
ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, ctx.OutputName("Out"));
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(out);
......@@ -77,6 +64,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// TODO(jczaja): Disable checking mid in unit tests (Require API change)
mid->mutable_data<T>(ctx.GetPlace());
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
const float k = ctx.Attr<float>("k");
e_mid = e_mid.constant(k);
mid->set_format(platform::GetMKLDNNFormat(*dst_memory));
......
......@@ -162,7 +162,7 @@ class MKLDNNHandlerT {
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, void* ptr, const std::string& suffix) {
auto local_key = key_ + suffix;
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
......@@ -174,6 +174,24 @@ class MKLDNNHandlerT {
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(md, engine_);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
const auto local_key = key_ + suffix;
return std::static_pointer_cast<mkldnn::memory>(
dev_ctx_.GetBlob(local_key));
}
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
platform::Place place_;
......@@ -535,22 +553,40 @@ template <typename T>
class LRNMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
public:
LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
const float alpha, const float beta, const float k,
const MKLDNNMemoryFormat fmt, bool is_test,
LRNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& unique_name)
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, unique_name)) {
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// Where sum of squares is divided by size of normalization window
// this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this diffrence by
// multipliing alpha by size of window(n)
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto dims = paddle::framework::vectorize(input->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
}
}
LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
const float alpha, const float beta, const float k,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册