Unverified commit dfdb0359, authored by Jacek Czaja, committed by GitHub

- Disabling oneDNN inplace pass (#30588)

Parent 430f8449
@@ -224,12 +224,11 @@ void CpuPassStrategy::EnableMKLDNN() {
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass",
#ifndef _WIN32
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
"mkldnn_inplace_pass", // This pass should be activated after
// fuses
#endif
//"mkldnn_inplace_pass", // This pass should be activated after
// fuses. Disabled by default due to
// little gain and lots of problems
})) {
passes_.push_back(pass);
}
......
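The hunk above only comments `mkldnn_inplace_pass` out of the default CPU pass strategy instead of deleting it, so the pass itself still exists. Below is a minimal sketch of how it could still be enabled explicitly from the C++ inference API, assuming the usual `AnalysisConfig` / `pass_builder()` interface of this Paddle version (the model path is a placeholder):

```cpp
#include "paddle_inference_api.h"  // public Paddle Inference header (assumed install layout)

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/model_dir");  // placeholder model directory
  config.EnableMKLDNN();                  // installs the CpuPassStrategy modified above

  // The in-place pass is no longer appended by default; re-adding it by name
  // restores the previous behaviour for experimentation.
  config.pass_builder()->AppendPass("mkldnn_inplace_pass");
  return 0;
}
```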
@@ -99,17 +99,17 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
"5, or 6, but now the dimension size is",
x->dims().size()));
bool is_inplaced = x->IsSharedBufferWith(*y);
auto src_tz = framework::vectorize<int64_t>(x->dims());
auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
platform::ActivationMKLDNNHandler<T> handler(
src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
ctx.InputName("X"));
ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p =
x->IsSharedBufferWith(*y) ? src_memory_p : handler.AcquireDstMemory(y);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
auto activation_p = handler.AcquireForwardPrimitive();
mkldnn::stream astream(dev_ctx.GetEngine());
......
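In this hunk, whether the op runs in place is computed once as `is_inplaced` and then reused both for the handler key and for the destination-memory decision. A self-contained toy sketch of what the shared-buffer test means (not Paddle's actual `Tensor::IsSharedBufferWith` implementation, just the aliasing idea):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy tensor: holds a shared allocation so aliasing can be tested.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> buffer;
  bool IsSharedBufferWith(const ToyTensor& other) const {
    // "In-place" simply means both tensors point at the same allocation.
    return buffer != nullptr && buffer == other.buffer;
  }
};

int main() {
  auto storage = std::make_shared<std::vector<float>>(32 * 64, 1.0f);
  ToyTensor x{storage};
  ToyTensor y_inplace{storage};  // Out reuses X's buffer
  ToyTensor y_regular{std::make_shared<std::vector<float>>(32 * 64, 0.0f)};

  assert(x.IsSharedBufferWith(y_inplace));   // reuse src memory as dst
  assert(!x.IsSharedBufferWith(y_regular));  // acquire a separate dst memory
  return 0;
}
```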
@@ -48,13 +48,17 @@ class SoftmaxMKLDNNHandler
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
Tensor* output, const int axis,
const std::string uniq_name)
const std::string uniq_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
axis, uniq_name)) {
is_inplaced ? platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
axis, uniq_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
uniq_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
@@ -78,7 +82,7 @@ class SoftmaxMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, axis, uniq_name)) {
platform::CreateKey(dev_ctx, dims, uniq_name)) {
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md =
@@ -98,17 +102,18 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
bool is_inplaced = input->IsSharedBufferWith(*output);
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
input, output, axis, ctx.OutputName("Out"));
input, output, axis, ctx.OutputName("Out"),
is_inplaced);
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
// For Inplace src and dst are the same memory object
auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
? softmax_src_memory_p
: handler.AcquireDstMemory(output);
auto softmax_dst_memory_p =
is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);
auto softmax_p = handler.AcquireForwardPrimitive();
......
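The comment "Softmax may be inplace then uniq_name is no longer unique" is what motivates the key split above: in-place runs fold `axis` into the handler key, while the regular path (and the backward handler) keep the shorter key so their cached primitives still match. A hypothetical string-based key builder, independent of Paddle's real `platform::CreateKey`, just to show why the two variants land in different cache slots:

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Toy key builder: concatenates every part with a separator (C++17 fold expression).
template <typename... Args>
std::string MakeKey(const Args&... args) {
  std::ostringstream key;
  ((key << args << '-'), ...);
  return key.str();
}

int main() {
  const std::string dims = "32x64";  // stand-in for the vectorized input dims
  const int axis = 1;
  const std::string uniq_name = "softmax_out";

  // Regular path: the output variable name keeps the key unique on its own.
  const std::string regular_key = MakeKey(dims, uniq_name);
  // In-place path: Out aliases X, so axis is added to keep keys distinct.
  const std::string inplace_key = MakeKey(dims, axis, uniq_name);

  assert(regular_key != inplace_key);  // separate cached oneDNN objects
  return 0;
}
```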
@@ -153,7 +153,18 @@ TEST(test_softmax_inplace_cache, cpu_place) {
CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out");
RunOperator<float>(p, "softmax", dims, "softmax_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(4), true,
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_relu_inplace_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "relu", dims, "relu_out");
RunOperator<float>(p, "relu", dims, "relu_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
......
@@ -614,12 +614,15 @@ class ActivationMKLDNNHandler
const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
const std::string& unique_name)
const std::string& unique_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
is_inplaced
? platform::CreateKey(dev_ctx, dims, "a", algorithm,
unique_name)
: platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
@@ -637,7 +640,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
......