Unverified commit 56008aa1 authored by Jacek Czaja, committed by GitHub

[oneDNN] Pool softmax and LRN access to cache optimized (#32922)

Parent af89a943
......@@ -14,21 +14,104 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace framework {
class Tensor;
} // namespace framework
namespace platform {
class MKLDNNDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
template <typename T>
class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
mkldnn::lrn_backward> {
public:
LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCachedNonBlocking()) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in the Caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// where the sum of squares is divided by the size of the normalization
// window; this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this difference by
// multiplying alpha by the size of the window (n).
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
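// Illustrative example (hypothetical values): with a window size n = 5 and a
// PaddlePaddle alpha of 1e-4, oneDNN is configured with alpha = 5e-4, so the
// division by n performed inside oneDNN cancels out and both frameworks
// compute the same normalization.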
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto dims = framework::vectorize(input->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptorNonBlocking(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
}
}
LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const Tensor* in_x,
const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
unique_name)) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::PreconditionNotMet(
"is_test attribute should be set to False in training phase."));
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
auto dims = framework::vectorize<int64_t>(in_x->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
in_x->format());
auto diff_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptorNonBlocking(
mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
beta, k);
}
}
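// Allocates the forward workspace tensor with the size required by the
// forward primitive descriptor; oneDNN LRN uses this workspace to pass
// intermediate data from the forward to the backward pass.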
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(Tensor* workspace) {
T* ptr = workspace->mutable_data<T>(
this->place_, this->fwd_pd_->workspace_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
ptr, "@wrk_mem_p");
}
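// Wraps the already computed forward workspace (the MidOut tensor) so it
// can be fed to the backward primitive.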
std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
const Tensor* workspace) {
const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_desc(),
platform::to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
}
};
template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -48,8 +131,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto out = ctx.Output<Tensor>("Out");
auto mid = ctx.Output<Tensor>("MidOut");
platform::LRNMKLDNNHandler<T> handler(
ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, ctx.OutputName("Out"));
LRNMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x,
ctx.OutputName("Out"));
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(out);
......@@ -87,34 +170,22 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL LRNGrad must use CPUPlace"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::PreconditionNotMet(
"is_test attribute should be set to False in training phase."));
auto x = ctx.Input<Tensor>("X");
auto in_x = ctx.Input<Tensor>("X");
auto mid = ctx.Input<Tensor>("MidOut");
auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
auto in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto dims = paddle::framework::vectorize<int64_t>(x->dims());
LRNMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), in_x, out_grad,
in_x_grad, ctx.InputName("Out"));
platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
out_grad->format(), dev_ctx,
ctx.GetPlace(), ctx.InputName("Out"));
auto src_memory = handler.AcquireSrcMemory(x);
auto src_memory = handler.AcquireSrcMemory(in_x);
auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
auto lrn_bwd = handler.AcquireBackwardPrimitive();
......@@ -125,8 +196,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
{MKLDNN_ARG_WORKSPACE, *workspace}});
astream.wait();
x_grad->set_layout(framework::DataLayout::kMKLDNN);
x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
}
};
} // namespace operators
......
......@@ -43,7 +43,7 @@ class PoolingMKLDNNHandler
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
framework::ToMKLDNNDataType(input->type()),
unique_name)) {
if (!this->isCached()) {
if (!this->isCachedNonBlocking()) {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input tensor."));
......@@ -100,11 +100,10 @@ class PoolingMKLDNNHandler
const auto is_test = ctx.Attr<bool>("is_test");
const auto dt = framework::ToMKLDNNDataType(input->type());
const auto fmt = input->format();
const auto exclude_padding = ctx.Attr<bool>("exclusive");
const auto src_md = mkldnn::memory::desc(src_tz, dt, fmt);
const auto src_md = mkldnn::memory::desc(src_tz, dt, input->format());
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
......@@ -124,7 +123,7 @@ class PoolingMKLDNNHandler
ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);
this->AcquireForwardPrimitiveDescriptor(
this->AcquireForwardPrimitiveDescriptorNonBlocking(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
pooling_type == "max"
......@@ -200,6 +199,10 @@ class PoolingMKLDNNHandler
auto diff_dst_tz =
paddle::framework::vectorize<int64_t>(out_grad->dims());
const auto dt = framework::ToMKLDNNDataType(in_x->type());
auto src_md = mkldnn::memory::desc(src_tz, dt, in_x->format());
auto dst_md =
mkldnn::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any);
auto diff_dst_md = mkldnn::memory::desc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
auto diff_src_md =
......@@ -216,7 +219,18 @@ class PoolingMKLDNNHandler
ComputeAdaptivePoolParameters(ctx, diff_src_tz, &ksize, &strides);
const auto exclude_padding = ctx.Attr<bool>("exclusive");
this->AcquireBackwardPrimitiveDescriptor(
this->AcquireForwardPrimitiveDescriptorNonBlocking(
mkldnn::prop_kind::forward_training,
pooling_type == "max"
? mkldnn::algorithm::pooling_max
: (exclude_padding
? mkldnn::algorithm::pooling_avg_exclude_padding
: mkldnn::algorithm::pooling_avg_include_padding),
src_md, dst_md, strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1]);
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
pooling_type == "max"
? mkldnn::algorithm::pooling_max
: (exclude_padding
......
......@@ -50,7 +50,7 @@ class SoftmaxMKLDNNHandler
: platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
uniq_name)) {
if (!this->isCached()) {
if (!this->isCachedNonBlocking()) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
platform::errors::InvalidArgument(
......@@ -60,8 +60,8 @@ class SoftmaxMKLDNNHandler
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
this->AcquireForwardPrimitiveDescriptorNonBlocking(
prop_kind::forward_scoring, md, axis);
}
}
......@@ -90,8 +90,10 @@ class SoftmaxMKLDNNHandler
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
this->AcquireForwardPrimitiveDescriptorNonBlocking(
prop_kind::forward_scoring, data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
diff_softmax_md, data_softmax_md, axis);
}
}
};
......
......@@ -126,13 +126,20 @@ class MKLDNNHandlerT {
return (dev_ctx_.GetBlob(key_p) != nullptr);
}
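// Checks whether the forward primitive descriptor is already cached and, on
// a hit, loads it into fwd_pd_ so it can be reused without rebuilding the
// descriptor.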
bool isCachedNonBlocking() {
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
return (fwd_pd_ != nullptr);
}
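// Analogous check for the backward primitive descriptor; on a hit it is
// loaded into bwd_pd_.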
bool isBwdCached() {
const std::string key_pd = key_common_ + "@bwd_pd";
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
const std::string key_p = key_ + "@bwd_p";
return (dev_ctx_.GetBlob(key_p) != nullptr);
return (bwd_pd_ != nullptr);
}
// If your primitive descriptor requires attributes, pass them as a
......@@ -161,6 +168,20 @@ class MKLDNNHandlerT {
}
}
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptorNonBlocking(Arg&& first_arg,
Args&&... args) {
// This is used when we can recreate FWD PD in BWD so
// we do not need to pass FWD to BWD
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_pd, fwd_pd_);
}
}
// Using SFINAE to specialise a variadic function. Workaround for not having
// if constexpr in C++11.
template <class First, class... Args>
......@@ -182,6 +203,8 @@ class MKLDNNHandlerT {
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
// TODO(jczaja): After/if all ops can use the xxxNonBlocking version,
// remove this one.
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
const std::string key_fwd_pd = key_common_ + "@fwd_pd";
......@@ -201,6 +224,25 @@ class MKLDNNHandlerT {
}
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptorNonBlocking(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptorNonBlocking
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_pd_);
}
}
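// Usage sketch (illustrative only; algo, src_md and diff_md are placeholders).
// It mirrors the LRN, pooling and softmax handlers updated in this change: a
// backward handler recreates the forward PD and then builds the backward PD
// from it:
//
//   if (!this->isBwdCached()) {
//     this->AcquireForwardPrimitiveDescriptorNonBlocking(
//         mkldnn::prop_kind::forward_training, algo, src_md, ...);
//     this->AcquireBackwardPrimitiveDescriptorNonBlocking(
//         algo, src_md, diff_md, ...);
//   }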
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
const std::string& suffix) {
return std::static_pointer_cast<mkldnn::memory>(
......@@ -781,82 +823,6 @@ class ActivationMKLDNNHandler
}
};
template <typename T>
class LRNMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
public:
LRNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in the Caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// where the sum of squares is divided by the size of the normalization
// window; this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this difference by
// multiplying alpha by the size of the window (n).
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto dims = paddle::framework::vectorize(input->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
}
}
LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
const float alpha, const float beta, const float k,
const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, unique_name)) {
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha, beta,
k);
}
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
framework::Tensor* workspace) {
T* ptr = workspace->mutable_data<T>(
this->place_, this->fwd_pd_->workspace_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
ptr, "@wrk_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
const framework::Tensor* workspace) {
const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
to_void_cast<T>(workspace_data),
"@bwd-wrk_mem_p");
}
};
template <typename T>
class TransposeMKLDNNHandler : public MKLDNNHandler {
public:
......
......@@ -63,4 +63,6 @@ class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()