From 31f0221f02e2a64a83613473d00c9906b973ab5a Mon Sep 17 00:00:00 2001
From: Jacek Czaja
Date: Thu, 26 Aug 2021 17:26:46 +0200
Subject: [PATCH] [oneDNN] disable caching oneDNN primitives in matmul v2,
 Reduce grad and elementwise_add grad, expand_v2 (#35132)

* - grad caching disabled of matmul_v1

- compilation fix

- compilation fix

* - reduction removed

* - Matmul v2 disabled caching

* Draft of further changes

* - workaround for reducegrad

* - fixes to UT

* - fix to compilation

* - another fix

* - fix
---
 .../mkldnn/elementwise_add_mkldnn_op.cc       |   6 +-
 .../mkldnn/elementwise_mul_mkldnn_op.cc       |   6 +-
 .../operators/mkldnn/expand_v2_mkldnn_op.cc   |   8 +-
 .../operators/mkldnn/matmul_mkldnn_op.cc      | 106 +++++++--------
 .../fluid/operators/mkldnn/matmul_mkldnn_op.h |   3 +-
 .../operators/mkldnn/matmul_v2_mkldnn_op.cc   | 122 ++++++++----------
 .../reduce_ops/mkldnn/reduce_mkldnn_op.h      |  50 +++----
 paddle/fluid/platform/mkldnn_reuse.h          | 103 +++++++--------
 8 files changed, 179 insertions(+), 225 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
index 8f519de0757..6cea4bfb990 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
@@ -84,10 +84,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     } else {  // Broadcasting
       platform::ReductionMKLDNNHandler<T> handler_sum(
-          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
-          ctx.GetPlace(), dout, dy,
-          ctx.InputName(framework::GradVarName("Out")),
-          CalculateBroadcastedDims(dout, dy));
+          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
+          ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
       auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
       auto reduction_p = handler_sum.AcquireForwardPrimitive();
       reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
index af4aab80478..2acf1e0fcd7 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
@@ -101,10 +101,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     // Reduction is needed for broadcasting scenario
     if (dout->dims() != dy->dims()) {
       platform::ReductionMKLDNNHandler<T> handler_sum(
-          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
-          ctx.GetPlace(), dout, dy,
-          ctx.InputName(framework::GradVarName("Out")),
-          CalculateBroadcastedDims(dout, dy));
+          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
+          ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
       auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
       auto reduction_p = handler_sum.AcquireForwardPrimitive();
       // As source we use mem object with results from binary operation
diff --git a/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc
index 3e630856409..829c948c1a5 100644
--- a/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc
@@ -53,8 +53,8 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
     out->Resize(paddle::framework::make_ddim(out_new_dims));
     out->set_format(x_format_tag);
     paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
-        dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(),
-        out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims);
+        dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x,
+        0.0f, 1.0f, x_vec_dims);

     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto dst_memory_p = handler.AcquireDstMemory(out);
@@ -136,8 +136,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
           paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc()));
     } else {
       paddle::platform::ReductionMKLDNNHandler<T> handler(
-          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
-          ctx.GetPlace(), dout, dx, ctx.InputName("X"), dx_vec_dims);
+          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
+          ctx.GetPlace(), dout, dx, dx_vec_dims);

       auto src_memory_p = handler.AcquireSrcMemory(dout);
       auto dst_memory_p = handler.AcquireDstMemory(dx);
diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
index 7ebd3e38560..35f93eba690 100644
--- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
@@ -83,58 +83,52 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,

 template <typename T>
 class MatMulMKLDNNHandler
-    : public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
+    : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
  public:
-  MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
-                      const mkldnn::engine engine,
+  MatMulMKLDNNHandler(const mkldnn::engine engine,
                       paddle::platform::Place cpu_place, Tensor* x,
                       bool trans_x, Tensor* y, bool trans_y, Tensor* out,
-                      float scale, const std::string& uniq_name)
-      : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
-            dev_ctx, engine, cpu_place,
-            paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
-                                        uniq_name)) {
-    if (!this->isCached()) {
-      auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
-          x->dims(), 0, trans_x);
-      auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
-          y->dims(), 0, trans_y);
-
-      memory::dim x_bs = mat_dim_x.batch_size_;
-      memory::dim y_bs = mat_dim_y.batch_size_;
-
-      memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
-      const memory::dim M = mat_dim_x.height_;
-      const memory::dim N = mat_dim_y.width_;
-      const memory::dim K = mat_dim_x.width_;
-
-      memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
-      memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
-      memory::dims out_dims = {out_bs, M, N};
-
-      memory::dims x_strides =
-          !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};
-
-      memory::dims y_strides =
-          !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
-      memory::dims out_strides = memory::dims{M * N, N, 1};
-
-      auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
-      auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
-      auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
-
-      dnnl::primitive_attr attrs;
-      if (scale != 1.0f) attrs.set_output_scales(0, {scale});
-
-      this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
-    }
+                      float scale)
+      : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
+                                                                   cpu_place) {
+    auto mat_dim_x =
+        paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
+    auto mat_dim_y =
+        paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y);
+
+    memory::dim x_bs = mat_dim_x.batch_size_;
+    memory::dim y_bs = mat_dim_y.batch_size_;
+
+    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
+    const memory::dim M = mat_dim_x.height_;
+    const memory::dim N = mat_dim_y.width_;
+    const memory::dim K = mat_dim_x.width_;
+
+    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
+    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
+    memory::dims out_dims = {out_bs, M, N};
+
+    memory::dims x_strides =
+        !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};
+
+    memory::dims y_strides =
+        !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
+    memory::dims out_strides = memory::dims{M * N, N, 1};
+
+    auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
+    auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
+    auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
+
+    dnnl::primitive_attr attrs;
+    if (scale != 1.0f) attrs.set_output_scales(0, {scale});
+
+    this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
   }

   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
     const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@weights_mem_p");
+                                            to_void_cast<T>(input_data));
   }
 };
@@ -565,7 +559,7 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
     const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
     const mkldnn::engine& engine, Tensor* x, bool trans_x,
     bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
-    Tensor* out, int execution_number) const {
+    Tensor* out) const {
   // gradient is calculated in a different way when broadcasting is used
   bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
                       out->dims().size() == 2;
@@ -583,10 +577,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
   float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

-  MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
-                                 trans_x, &y_combined, trans_y, out, alpha,
-                                 ctx.InputName(framework::GradVarName("Out")) +
-                                     std::to_string(execution_number));
+  MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
+                                 &y_combined, trans_y, out, alpha);

   const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
   const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
@@ -645,24 +637,24 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {

   if (transpose_x && transpose_y) {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
-                            true, false, dx, 0);
+                            true, false, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
-                            true, false, dy, 1);
+                            true, false, dy);
   } else if (transpose_x) {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
-                            &dout, true, false, dx, 0);
+                            &dout, true, false, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
-                            &dout, false, true, dy, 1);
+                            &dout, false, true, dy);
   } else if (transpose_y) {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
-                            &y, false, true, dx, 0);
+                            &y, false, true, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
-                            false, true, dy, 1);
+                            false, true, dy);
   } else {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
-                            &y, true, false, dx, 0);
+                            &y, true, false, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
-                            false, true, dy, 1);
+                            false, true, dy);
   }

   if (dx) {
diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h
index 725d1fff9c6..69ae78fcca0 100644
--- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h
+++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h
@@ -34,8 +34,7 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
                           const MKLDNNDeviceContext& dev_ctx,
                           const mkldnn::engine& engine, Tensor* x, bool trans_x,
                           bool is_fold_init_dims_x, Tensor* y, bool trans_y,
-                          bool is_fold_init_dims_y, Tensor* out,
-                          int execution_number) const;
+                          bool is_fold_init_dims_y, Tensor* out) const;
   void RunKernel(const ExecutionContext& ctx) const;
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
index b5dc096441c..57a3c385593 100644
--- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -31,79 +31,72 @@ using paddle::framework::GradVarName;

 template <typename T>
 class MatMulV2MKLDNNHandler
-    : public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
+    : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
  public:
-  MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
-                        const mkldnn::engine engine,
+  MatMulV2MKLDNNHandler(const mkldnn::engine engine,
                         paddle::platform::Place cpu_place,
                         const std::vector<int64_t>& x_org_dims, bool trans_x,
-                        const std::vector<int64_t>& y_org_dims, bool trans_y,
-                        const std::string& uniq_name)
-      : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
-            dev_ctx, engine, cpu_place,
-            paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
-    if (!this->isCached()) {
-      // M X K * K X N
-      std::vector<int64_t> x_dims(x_org_dims);
-      std::vector<int64_t> y_dims(y_org_dims);
-
-      const int MB_idx = x_dims.size() - 3;
-      const int H_idx = x_dims.size() - 2;
-      const int W_idx = x_dims.size() - 1;
-
-      if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
-      if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
-
-      const memory::dim M = x_dims[H_idx];
-      const memory::dim K = x_dims[W_idx];
-      const memory::dim N = y_dims[W_idx];
-
-      std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
-      std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
-      std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
-      std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
-
-      x_strides.reserve(x_dims.size());
-      y_strides.reserve(x_dims.size());
-      out_strides.reserve(x_dims.size());
-
-      if (!trans_x) {
-        x_strides.insert(x_strides.end(), {M * K, K, 1});
-      } else {
-        x_strides.insert(x_strides.end(), {M * K, 1, M});
-      }
+                        const std::vector<int64_t>& y_org_dims, bool trans_y)
+      : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
+                                                                   cpu_place) {
+    // M X K * K X N
+    std::vector<int64_t> x_dims(x_org_dims);
+    std::vector<int64_t> y_dims(y_org_dims);
+
+    const int MB_idx = x_dims.size() - 3;
+    const int H_idx = x_dims.size() - 2;
+    const int W_idx = x_dims.size() - 1;
+
+    if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
+    if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
+
+    const memory::dim M = x_dims[H_idx];
+    const memory::dim K = x_dims[W_idx];
+    const memory::dim N = y_dims[W_idx];
+
+    std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
+    std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
+    std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
+    std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
+
+    x_strides.reserve(x_dims.size());
+    y_strides.reserve(x_dims.size());
+    out_strides.reserve(x_dims.size());
+
+    if (!trans_x) {
+      x_strides.insert(x_strides.end(), {M * K, K, 1});
+    } else {
+      x_strides.insert(x_strides.end(), {M * K, 1, M});
+    }

-      if (!trans_y) {
-        y_strides.insert(y_strides.end(), {N * K, N, 1});
-      } else {
-        y_strides.insert(y_strides.end(), {N * K, 1, K});
-      }
+    if (!trans_y) {
+      y_strides.insert(y_strides.end(), {N * K, N, 1});
+    } else {
+      y_strides.insert(y_strides.end(), {N * K, 1, K});
+    }

-      out_strides.insert(out_strides.end(), {M * N, N, 1});
-      out_ddims.insert(out_ddims.end(),
-                       {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
+    out_strides.insert(out_strides.end(), {M * N, N, 1});
+    out_ddims.insert(out_ddims.end(),
+                     {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});

-      for (int i = x_dims.size() - 4; i >= 0; --i) {
-        out_ddims[i] = std::max(x_dims[i], y_dims[i]);
-        x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
-        y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
-        out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
-      }
+    for (int i = x_dims.size() - 4; i >= 0; --i) {
+      out_ddims[i] = std::max(x_dims[i], y_dims[i]);
+      x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
+      y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
+      out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
+    }

-      auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
-      auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
-      auto out_md =
-          memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
+    auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
+    auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
+    auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);

-      this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
-    }
+    this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
   }

   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
     const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@weights_mem_p");
+                                            to_void_cast<T>(input_data));
   }
 };
@@ -122,9 +115,8 @@ class MatMulV2MKLDNNKernel
                         const Tensor* y, std::vector<int64_t>& y_dims,
                         bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
                         int execution_number = 0) const {
-    MatMulV2MKLDNNHandler<T> handler(
-        dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
-        trans_y, ctx.InputName("X") + std::to_string(execution_number));
+    MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
+                                     trans_x, y_dims, trans_y);

     const auto src_memory_p = handler.AcquireSrcMemory(x);
     const auto weights_memory_p = handler.AcquireWeightsMemory(y);
@@ -251,8 +243,8 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
                   const Tensor* dx_tmp, Tensor* dx,
                   std::vector<int64_t> dx_dims) const {
     paddle::platform::ReductionMKLDNNHandler<T> handler(
-        dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
-        ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
+        dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
+        ctx.GetPlace(), dx_tmp, dx, dx_dims);

     auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
     auto dst_memory_p = handler.AcquireDstMemory(dx);
diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h
index 6a9aae046f3..0165cfd8b80 100644
--- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h
+++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h
@@ -96,9 +96,9 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
           platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
               paddle::framework::vectorize(output->dims()))));
     } else {
-      platform::ReductionMKLDNNHandler<T> handler(
-          reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(),
-          input, output, ctx.InputName("X"), output_dims);
+      platform::ReductionMKLDNNHandler<T> handler(reduction_type, 0.0f, 0.0f,
+                                                  onednn_engine, ctx.GetPlace(),
+                                                  input, output, output_dims);

       auto src_memory_p = handler.AcquireSrcMemory(input);
       auto dst_memory_p = handler.AcquireDstMemory(output);
@@ -137,40 +137,28 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
     mkldnn::memory::format_tag x_format_tag;
     auto input_dims =
         CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
+    auto output_dims = framework::vectorize(output_dx->dims());

-    if (input_dims != framework::vectorize(output_dx->dims())) {
-      const std::string key_pd =
-          platform::CreateKey(
-              dev_ctx, framework::vectorize(output_dx->dims()),
-              ctx.InputName("X"),
-              (std::to_string(static_cast<int>(reduction_type)))) +
-          "@fwd_pd";
-      std::shared_ptr<dnnl::reduction::primitive_desc> fwd_pd =
-          std::static_pointer_cast<dnnl::reduction::primitive_desc>(
-              dev_ctx.GetBlob(key_pd));
-
-      PADDLE_ENFORCE_NOT_NULL(
-          fwd_pd, platform::errors::Unavailable(
-                      "Forward primitive descriptor is not available in %s op, "
-                      "cannot deduce memory format tag",
-                      ctx.Type()));
-
-      x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc());
-
-      PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef,
-                        platform::errors::InvalidArgument(
-                            "Cannot deduce format tag for %s op", ctx.Type()));
-    } else {  // fwd descriptor not available because reorder was used instead
-              // of reduction
+    if (input_dims != output_dims) {
+      auto input_dy_md = dnnl::memory::desc(
+          framework::vectorize(input_dy->dims()),
+          platform::MKLDNNGetDataType<T>(), input_dy->format());
+      auto input_dy_ex_md = input_dy_md.reshape(input_dims);
+      // TODO(jczaja): once MD is stored in Tensor we no longer need to guess
+      // formats
+      x_format_tag = platform::GetMKLDNNFormat(input_dy_ex_md);
+
+    } else {
+      // There was no broadcasting then just simple copy is done
+      // same format used for input and output
       x_format_tag = getPlainFormatTag(output_dx);
     }

     output_dx->set_format(x_format_tag);

     platform::BroadcastDataMKLDNNHandler<T> handler(
-        binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
-        input_dy, scale_x, scale_y,
-        ctx.InputName(framework::GradVarName("Out")), input_dims);
+        binary_type, onednn_engine, ctx.GetPlace(), output_dx, input_dy,
+        scale_x, scale_y, input_dims);

     const auto src_memory_p = handler.AcquireSrcMemory(input_dy);
     const auto dst_memory_p = handler.AcquireDstMemory(output_dx);
@@ -184,6 +172,8 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     binary_prim->execute(astream, args);
     astream.wait();
+
+    output_dx->set_layout(framework::DataLayout::kMKLDNN);
   }

 protected:
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index c27bc6c6e55..e6442ded6b5 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -895,41 +895,34 @@ class BinaryMKLDNNHandler

 template <typename T>
 class BroadcastDataMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::binary> {
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
  public:
   BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
-                             const MKLDNNDeviceContext& dev_ctx,
                              const mkldnn::engine engine,
                              platform::Place cpu_place, const Tensor* out,
                              const Tensor* x, float scale_x, float scale_y,
-                             const std::string& uniq_name,
                              const std::vector<int64_t>& input_dims)
-      : platform::MKLDNNHandlerT<T, dnnl::binary>(
-            dev_ctx, engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
-                                uniq_name)) {
-    if (!this->isCached()) {
-      PADDLE_ENFORCE_EQ(
-          x->layout(), DataLayout::kMKLDNN,
-          platform::errors::InvalidArgument("Wrong layout set for X tensor."));
-      PADDLE_ENFORCE_NE(
-          x->format(), MKLDNNMemoryFormat::undef,
-          platform::errors::InvalidArgument("Wrong format set for X tensor."));
-
-      const auto src0_tz = framework::vectorize(out->dims());
-
-      const auto src0_md = dnnl::memory::desc(
-          src0_tz, platform::MKLDNNGetDataType<T>(), out->format());
-      const auto src1_md = dnnl::memory::desc(
-          input_dims, platform::MKLDNNGetDataType<T>(), out->format());
-
-      dnnl::primitive_attr attributes;
-      attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
-      attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
-
-      this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
-                                              src1_md, src0_md);
-    }
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
+    PADDLE_ENFORCE_EQ(
+        x->layout(), DataLayout::kMKLDNN,
+        platform::errors::InvalidArgument("Wrong layout set for X tensor."));
+    PADDLE_ENFORCE_NE(
+        x->format(), MKLDNNMemoryFormat::undef,
+        platform::errors::InvalidArgument("Wrong format set for X tensor."));
+
+    const auto src0_tz = framework::vectorize(out->dims());
+
+    const auto src0_md = dnnl::memory::desc(
+        src0_tz, platform::MKLDNNGetDataType<T>(), out->format());
+    const auto src1_md = dnnl::memory::desc(
+        input_dims, platform::MKLDNNGetDataType<T>(), out->format());
+
+    dnnl::primitive_attr attributes;
+    attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
+    attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
+
+    this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
+                                            src0_md);
   }

   template <typename T_out = T>
@@ -938,43 +931,35 @@ class BroadcastDataMKLDNNHandler
         this->place_, this->fwd_pd_->dst_desc().get_size());
     ;
     memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr,
-                                            "@dst_mem_p");
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
   }
 };

 template <typename T>
 class ReductionMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::reduction> {
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> {
  public:
   ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
-                         const float eps, const MKLDNNDeviceContext& dev_ctx,
-                         const mkldnn::engine engine, platform::Place cpu_place,
-                         const Tensor* x, const Tensor* y,
-                         const std::string& uniq_name,
-                         std::vector<int64_t> y_tz)
-      : platform::MKLDNNHandlerT<T, dnnl::reduction>(
-            dev_ctx, engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
-                                uniq_name,
-                                (std::to_string(static_cast<int>(algo))))) {
-    if (!this->isCached()) {
-      PADDLE_ENFORCE_EQ(
-          x->layout(), DataLayout::kMKLDNN,
-          platform::errors::InvalidArgument("Wrong layout set for X tensor."));
-      PADDLE_ENFORCE_NE(
-          x->format(), MKLDNNMemoryFormat::undef,
-          platform::errors::InvalidArgument("Wrong format set for X tensor."));
-
-      const auto x_tz = framework::vectorize(x->dims());
-
-      const auto x_md = dnnl::memory::desc(
-          x_tz, platform::MKLDNNGetDataType<T>(), x->format());
-      const auto y_md =
-          memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
-
-      this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
-    }
+                         const float eps, const mkldnn::engine engine,
+                         platform::Place cpu_place, const Tensor* x,
+                         const Tensor* y, std::vector<int64_t> y_tz)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
+                                                              cpu_place) {
+    PADDLE_ENFORCE_EQ(
+        x->layout(), DataLayout::kMKLDNN,
+        platform::errors::InvalidArgument("Wrong layout set for X tensor."));
+    PADDLE_ENFORCE_NE(
+        x->format(), MKLDNNMemoryFormat::undef,
+        platform::errors::InvalidArgument("Wrong format set for X tensor."));
+
+    const auto x_tz = framework::vectorize(x->dims());
+
+    const auto x_md =
+        dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType<T>(), x->format());
+    const auto y_md =
+        memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
+
+    this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
   }
 };
--
GitLab
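
A note on the pattern being replaced: MKLDNNHandlerT keeps the oneDNN primitive descriptor in the MKLDNNDeviceContext blob map under a string key built by platform::CreateKey from the input shape and a per-op unique name, and skips re-creation when isCached() reports a hit. MKLDNNHandlerNoCachingT drops the key, the map, and the dev_ctx argument and rebuilds the descriptor on every invocation, which is why every constructor in this patch loses its uniq_name parameter and its isCached() guard. The sketch below is a minimal standalone rendering of the two styles, not the Paddle classes themselves; it assumes the oneDNN 2.x C++ API (dnnl::matmul::desc plus primitive_desc) that Paddle built against at the time, and the names CachedMatmulHandler / NoCachingMatmulHandler are illustrative only.

#include <dnnl.hpp>

#include <string>
#include <unordered_map>

// Cached style (roughly what MKLDNNHandlerT does): look the descriptor up in
// a per-context map under a shape/op-derived key and create it only on a
// miss. The key must encode everything the descriptor depends on, or a stale
// descriptor gets reused.
struct CachedMatmulHandler {
  std::unordered_map<std::string, dnnl::matmul::primitive_desc> blob_map;

  dnnl::matmul::primitive_desc Acquire(const dnnl::engine& eng,
                                       const std::string& key,
                                       const dnnl::memory::desc& x_md,
                                       const dnnl::memory::desc& y_md,
                                       const dnnl::memory::desc& out_md) {
    auto it = blob_map.find(key);
    if (it != blob_map.end()) return it->second;  // the isCached() fast path
    dnnl::matmul::desc d(x_md, y_md, out_md);     // oneDNN 2.x op descriptor
    dnnl::matmul::primitive_desc pd(d, eng);
    blob_map.emplace(key, pd);
    return pd;
  }
};

// No-caching style (roughly MKLDNNHandlerNoCachingT): no dev_ctx, no key, no
// map; the descriptor is rebuilt unconditionally, exactly as the constructors
// in this patch now do.
struct NoCachingMatmulHandler {
  dnnl::matmul::primitive_desc Acquire(const dnnl::engine& eng,
                                       const dnnl::memory::desc& x_md,
                                       const dnnl::memory::desc& y_md,
                                       const dnnl::memory::desc& out_md) {
    dnnl::matmul::desc d(x_md, y_md, out_md);
    return dnnl::matmul::primitive_desc(d, eng);
  }
};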
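
The one hunk that is more than a signature change is ReduceGradMKLDNNKernel: the old code recovered the source memory format by reading the forward primitive descriptor back out of the cache under the "@fwd_pd" key, which no longer exists once caching is gone. The replacement builds a memory descriptor for input_dy and reshapes it to the unreduced (keep-dim) shape, then derives the format tag from that. A small sketch of the same deduction using raw oneDNN 2.x types; the shapes and the function name are made up for illustration:

#include <dnnl.hpp>

// dy is, say, a [2, 3] result of reducing a [2, 3, 4, 5] input over its two
// trailing axes. The grad kernel needs a descriptor in the pre-reduction
// rank without consulting any cached forward primitive descriptor.
inline dnnl::memory::desc deduce_expanded_desc() {
  dnnl::memory::desc dy_md({2, 3}, dnnl::memory::data_type::f32,
                           dnnl::memory::format_tag::ab);
  // CalculateReducedDims returns the keep-dim shape (reduced axes become 1),
  // so the reshape preserves the element count: 2*3 == 2*3*1*1.
  return dy_md.reshape({2, 3, 1, 1});  // a layout tag is now deducible for dx
}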
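
Nothing in the patch replaces the removed caching with anything explicit, and the commit message does not spell out the rationale; the usual argument for this family of changes is that oneDNN ships its own process-wide primitive cache, so re-creating an identical primitive is cheap. If primitive-creation cost ever does show up after a change like this, that internal cache can be inspected and resized at runtime; a tiny usage sketch against the oneDNN 2.x API:

#include <dnnl.hpp>
#include <iostream>

int main() {
  // Nonzero when oneDNN was built with its primitive cache enabled;
  // a capacity of 0 means full creation cost is paid on every call.
  std::cout << "capacity: " << dnnl::get_primitive_cache_capacity() << "\n";
  dnnl::set_primitive_cache_capacity(2048);  // e.g. enlarge for many shapes
  return 0;
}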