未验证 提交 dfdb0359 编写于 作者: J Jacek Czaja 提交者: GitHub

- Disabling oneDNN inplace pass (#30588)

上级 430f8449
...@@ -224,12 +224,11 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -224,12 +224,11 @@ void CpuPassStrategy::EnableMKLDNN() {
// "fc_mkldnn_pass", // "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass", // "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass", "batch_norm_act_fuse_pass",
#ifndef _WIN32
// TODO(intel): Please fix the bug on windows. // TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710 // https://github.com/PaddlePaddle/Paddle/issues/29710
"mkldnn_inplace_pass", // This pass should be activated after //"mkldnn_inplace_pass", // This pass should be activated after
// fuses // fuses. Disabled by default due to
#endif // little gain and lots of problems
})) { })) {
passes_.push_back(pass); passes_.push_back(pass);
} }
......
...@@ -99,17 +99,17 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -99,17 +99,17 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
"5, or 6, but now the dimension size is", "5, or 6, but now the dimension size is",
x->dims().size())); x->dims().size()));
bool is_inplaced = x->IsSharedBufferWith(*y);
auto src_tz = framework::vectorize<int64_t>(x->dims()); auto src_tz = framework::vectorize<int64_t>(x->dims());
auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format(); auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
platform::ActivationMKLDNNHandler<T> handler( platform::ActivationMKLDNNHandler<T> handler(
src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(), src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
ctx.InputName("X")); ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
x->IsSharedBufferWith(*y) ? src_memory_p : handler.AcquireDstMemory(y);
auto activation_p = handler.AcquireForwardPrimitive(); auto activation_p = handler.AcquireForwardPrimitive();
mkldnn::stream astream(dev_ctx.GetEngine()); mkldnn::stream astream(dev_ctx.GetEngine());
......
...@@ -48,13 +48,17 @@ class SoftmaxMKLDNNHandler ...@@ -48,13 +48,17 @@ class SoftmaxMKLDNNHandler
const mkldnn::engine mkldnn_engine, const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input, platform::Place cpu_place, const Tensor* input,
Tensor* output, const int axis, Tensor* output, const int axis,
const std::string uniq_name) const std::string uniq_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, mkldnn_engine, cpu_place, dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique // Softmax may be inplace then uniq_name is no longer unique
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), is_inplaced ? platform::CreateKey(
axis, uniq_name)) { dev_ctx, framework::vectorize(input->dims()),
axis, uniq_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
uniq_name)) {
if (!this->isCached()) { if (!this->isCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input->dims(), output->dims(), input->dims(), output->dims(),
...@@ -78,7 +82,7 @@ class SoftmaxMKLDNNHandler ...@@ -78,7 +82,7 @@ class SoftmaxMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, axis, uniq_name)) { platform::CreateKey(dev_ctx, dims, uniq_name)) {
auto data_softmax_md = auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md = auto diff_softmax_md =
...@@ -98,17 +102,18 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -98,17 +102,18 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const Tensor* input = ctx.Input<Tensor>("X"); const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
bool is_inplaced = input->IsSharedBufferWith(*output);
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size()); const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(), SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
input, output, axis, ctx.OutputName("Out")); input, output, axis, ctx.OutputName("Out"),
is_inplaced);
auto softmax_src_memory_p = handler.AcquireSrcMemory(input); auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
// For Inplace src and and dst are the same memory object // For Inplace src and and dst are the same memory object
auto softmax_dst_memory_p = input->IsSharedBufferWith(*output) auto softmax_dst_memory_p =
? softmax_src_memory_p is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);
: handler.AcquireDstMemory(output);
auto softmax_p = handler.AcquireForwardPrimitive(); auto softmax_p = handler.AcquireForwardPrimitive();
......
...@@ -153,7 +153,18 @@ TEST(test_softmax_inplace_cache, cpu_place) { ...@@ -153,7 +153,18 @@ TEST(test_softmax_inplace_cache, cpu_place) {
CacheTester ct; CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out"); RunOperator<float>(p, "softmax", dims, "softmax_out");
RunOperator<float>(p, "softmax", dims, "softmax_out", true); RunOperator<float>(p, "softmax", dims, "softmax_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(4), true, PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_relu_inplace_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "relu", dims, "relu_out");
RunOperator<float>(p, "relu", dims, "relu_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects")); "Wrong number of cached oneDNN objects"));
} }
......
...@@ -614,12 +614,15 @@ class ActivationMKLDNNHandler ...@@ -614,12 +614,15 @@ class ActivationMKLDNNHandler
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, platform::Place cpu_place,
const std::string& unique_name) const std::string& unique_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>( mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { is_inplaced
? platform::CreateKey(dev_ctx, dims, "a", algorithm,
unique_name)
: platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
...@@ -637,7 +640,7 @@ class ActivationMKLDNNHandler ...@@ -637,7 +640,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>( mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
auto diff_dst_md = platform::MKLDNNMemDesc( auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt); dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = auto src_md =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册