From 985f2a4a019e438f2ac4edfc10683e781f2f1ea5 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 26 Aug 2022 13:15:52 +0800 Subject: [PATCH] Transfer transfer_layout from fluid to phi (#45261) * remove fluid kernel and activate phi kernel * fix parameter error * transfer mkldnn part * modify header file path * fix compile error * transfer special case * fix lod setting and special case for layout setting * add testcase and refine code --- .../framework/new_executor/data_transfer.cc | 20 +++ paddle/fluid/operators/transfer_layout_op.cc | 42 ++---- paddle/phi/infermeta/unary.cc | 6 +- paddle/phi/infermeta/unary.h | 3 +- paddle/phi/kernels/CMakeLists.txt | 3 +- paddle/phi/kernels/funcs/CMakeLists.txt | 5 + .../kernels/funcs/data_layout_transform.cc | 123 ++++++++++++++++ .../phi/kernels/funcs/data_layout_transform.h | 75 ++++++++++ .../phi/kernels/funcs/onednn/mkldnn_helper.h | 134 ++++++++++++++++++ .../phi/kernels/funcs/onednn/mkldnn_reuse.h | 130 +++++++++++++++-- paddle/phi/kernels/transfer_layout_kernel.cc | 114 ++++++++++++++- paddle/phi/kernels/transfer_layout_kernel.h | 9 +- paddle/phi/tests/kernels/CMakeLists.txt | 5 + .../kernels/test_transfer_layout_dev_api.cc | 74 ++++++++++ 14 files changed, 683 insertions(+), 60 deletions(-) create mode 100644 paddle/phi/kernels/funcs/data_layout_transform.cc create mode 100644 paddle/phi/kernels/funcs/data_layout_transform.h create mode 100644 paddle/phi/kernels/funcs/onednn/mkldnn_helper.h create mode 100644 paddle/phi/tests/kernels/test_transfer_layout_dev_api.cc diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 32277ed54bb..6d0f7911945 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -18,6 +18,10 @@ #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/phi/backends/onednn/onednn_context.h" +#endif + namespace paddle { namespace framework { namespace interpreter { @@ -200,11 +204,27 @@ std::shared_ptr TransferLayout(const std::string& var_name, framework::Scope* local_scope, bool is_fetch_v2) { #ifdef PADDLE_WITH_MKLDNN + // NOTE(zhiqiu): hot fix, follow the same logic in DataCopy() in fetch_op.cc if (in_layout == framework::DataLayout::kMKLDNN && var_name == framework::GradVarName("Filter") && is_fetch_v2) { + VLOG(4) << "Match special case(Filter && fetch_v2) " << var_name; out_layout = framework::DataLayout::kNCHW; } + + if (in_layout == framework::DataLayout::MKLDNN && + out_layout != framework::DataLayout::MKLDNN) { + auto target_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + VLOG(4) << "TransDataLayoutFromMKLDNN: " << in_layout << "->" + << target_layout; + + if (out_layout == DataLayout::kNCHW && + var_name == framework::GradVarName("Filter")) { + VLOG(4) << "Match special case(Filter) " << var_name; + target_layout = out_layout; + } + out_layout = target_layout; + } #endif // 1. Generate new_var_name and Initialize it diff --git a/paddle/fluid/operators/transfer_layout_op.cc b/paddle/fluid/operators/transfer_layout_op.cc index 1e1ff3494d0..86862d4a10f 100644 --- a/paddle/fluid/operators/transfer_layout_op.cc +++ b/paddle/fluid/operators/transfer_layout_op.cc @@ -16,7 +16,11 @@ #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace framework { @@ -37,34 +41,6 @@ class TransferLayoutOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "TransferLayout"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TransferLayout"); - - auto dst_layout = ctx->Attrs().Get("dst_layout"); - auto low_bound = static_cast(framework::DataLayout::kAnyLayout); - auto upper_bound = static_cast(framework::DataLayout::kMKLDNN); - PADDLE_ENFORCE_GE( - dst_layout, - low_bound, - platform::errors::PreconditionNotMet( - "Required dst_layout >= %d, but received dst_layout = %d", - low_bound, - dst_layout)); - PADDLE_ENFORCE_LE( - dst_layout, - upper_bound, - platform::errors::PreconditionNotMet( - "Required dst_layout <= %d, but received dst_layout = %d", - upper_bound, - dst_layout)); - - // TODO(Aurelius84): Out's ddim is different with X because they have - // different layout - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -142,18 +118,18 @@ class TransferLayoutOpProtoMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(transfer_layout, + TransferLayoutInferShapeFunctor, + PD_INFER_META(phi::TransferLayoutInferMeta)); REGISTER_OPERATOR( transfer_layout, ops::TransferLayoutOp, ops::TransferLayoutOpProtoMaker, ops::TransferLayoutInferVarType, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + TransferLayoutInferShapeFunctor); -// dtype is not important -REGISTER_OP_CPU_KERNEL_FUNCTOR(transfer_layout, - float, - ops::TransferLayoutKernel); REGISTER_OP_VERSION(transfer_layout) .AddCheckpoint(R"ROC(refine transfer_layout, add src_layout attribute)ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 76142c4eea1..9925a10e6dc 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3649,11 +3649,13 @@ void TraceInferMeta( } void TransferLayoutInferMeta(const MetaTensor& x, - DataLayout layout, + int src_layout, + int dst_layout, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(x.dtype()); - out->set_layout(layout); + out->set_layout(static_cast(dst_layout)); + out->share_lod(x); } void TransposeInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 736360e7400..c0e20714a7b 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -519,7 +519,8 @@ void TraceInferMeta( const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out); void TransferLayoutInferMeta(const MetaTensor& x, - DataLayout layout, + int src_layout, + int dst_layout, MetaTensor* out); void TransposeInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index d40b6f589c5..47ae390fb6f 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -66,7 +66,8 @@ set(COMMON_KERNEL_DEPS phi_dynload_warpctc sequence_padding sequence_scale - fft) + fft + phi_data_layout_transform) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index e21bea2e242..122e4ba7fea 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -17,6 +17,11 @@ math_library(segment_pooling) math_library(sequence2batch) math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function) +cc_library( + phi_data_layout_transform + SRCS data_layout_transform.cc + DEPS tensor) + if(WITH_GPU OR WITH_ROCM) if(MKL_FOUND AND WITH_ONEMKL) math_library(fft spectral_op.cu DEPS dynload_cuda dynload_mklrt diff --git a/paddle/phi/kernels/funcs/data_layout_transform.cc b/paddle/phi/kernels/funcs/data_layout_transform.cc new file mode 100644 index 00000000000..800d67583e0 --- /dev/null +++ b/paddle/phi/kernels/funcs/data_layout_transform.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/data_layout_transform.h" + +#include "glog/logging.h" + +#include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" + +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/phi/kernels/funcs/onednn/mkldnn_helper.h" +#include "paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h" +#endif + +namespace phi { +namespace funcs { + +#ifdef PADDLE_WITH_MKLDNN + +void* GetDataFromTensor(const DenseTensor& tensor, + dnnl::memory::data_type type) { + switch (type) { + case dnnl::memory::data_type::f32: + return to_void_cast(tensor.data()); + case dnnl::memory::data_type::s8: + return to_void_cast(tensor.data()); + case dnnl::memory::data_type::u8: + return to_void_cast(tensor.data()); + case dnnl::memory::data_type::s32: + return to_void_cast(tensor.data()); + case dnnl::memory::data_type::bf16: + return to_void_cast(tensor.data()); + default: + PADDLE_THROW(errors::InvalidArgument("Wrong mkldnn type provided.")); + } +} + +void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, + DataLayout out_layout, + const DenseTensor& in, + DenseTensor* out, + Place place, + bool always_copy) { + // Set default as NCHW in case not specified + out_layout = out_layout == DataLayout::ANY ? DataLayout::NCHW : out_layout; + + auto& pool = DeviceContextPool::Instance(); + auto* dev_ctx = dynamic_cast(pool.Get(place)); + auto& cpu_engine = dev_ctx->GetEngine(); + + auto in_tz = vectorize(in.dims()); + auto out_tz = in_tz; + + auto in_type = ToMKLDNNDataType(in.dtype()); + PADDLE_ENFORCE_NE( + in_type, + MKLDNNDataType::undef, + errors::InvalidArgument("Input tensor type (%s) is not supported.", + in.dtype())); + + auto out_format = + MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout)); + dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format); + + // output tensor has the same dims as input. Reorder don't change dims + out->set_mem_desc(out_mem_desc); + out->Resize(in.dims()); + + if ((in.mem_desc() != out->mem_desc()) || always_copy) { + void* in_data = GetDataFromTensor(in, in_type); + + ReorderMKLDNNHandler handler(in_tz, in.dtype(), in_type, cpu_engine); + + auto reorder_src_memory_p = + handler.AcquireSrcMemory(in.mem_desc(), in_data); + auto reorder_dst_memory_p = + handler.AcquireDstMemory(out, out->mem_desc(), place); + auto reorder_p = + handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); + + auto& astream = OneDNNContext::tls().get_stream(); + ::paddle::platform::RecordEvent record_reorder( + "ext_reorder", + ::paddle::platform::TracerEventType::UserDefined, + 2, + ::paddle::platform::EventRole::kUniqueOp); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + } else { + out->ShareDataWith(in); + } + // For exepected NHWC data format we need to reshape the Output tensor + // As MKL-DNN description was in NCHW and paddle is expecting NHWC + MatchShapeToLayout(out, in_layout, out_layout); + + out->set_layout(DataLayout::kNCHW); + VLOG(10) << "out->layout: " << out->layout() << " in->dims: " << in.dims() + << " out->dims: " << out->dims(); + // reset format since the out tensor will be feed to non-MKLDNN OPkernel + out->set_format(MKLDNNMemoryFormat::undef); +} + +#endif + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/data_layout_transform.h b/paddle/phi/kernels/funcs/data_layout_transform.h new file mode 100644 index 00000000000..8fff3195b5c --- /dev/null +++ b/paddle/phi/kernels/funcs/data_layout_transform.h @@ -0,0 +1,75 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_MKLDNN +#include "dnnl.hpp" // NOLINT +#endif + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace funcs { + +#ifdef PADDLE_WITH_MKLDNN + +using MKLDNNDataType = dnnl::memory::data_type; +using MKLDNNMemoryFormat = dnnl::memory::format_tag; + +inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) { + switch (layout) { + case DataLayout::NHWC: + return MKLDNNMemoryFormat::nhwc; + case DataLayout::NCHW: + return MKLDNNMemoryFormat::nchw; + case DataLayout::NCDHW: + return MKLDNNMemoryFormat::ncdhw; + case DataLayout::NDHWC: + return MKLDNNMemoryFormat::ndhwc; + default: + PADDLE_THROW(errors::InvalidArgument( + "Fail to convert layout %s to MKLDNN format.", + ::paddle::framework::DataLayoutToString(layout))); + } +} + +// Caution: proto::VarType::Type -> phi::DataType after transfer +inline MKLDNNDataType ToMKLDNNDataType(DataType type) { + static std::unordered_map dict{ + {DataType::FLOAT32, MKLDNNDataType::f32}, + {DataType::INT8, MKLDNNDataType::s8}, + {DataType::UINT8, MKLDNNDataType::u8}, + {DataType::INT32, MKLDNNDataType::s32}, + {DataType::BFLOAT16, MKLDNNDataType::bf16}}; + auto iter = dict.find(type); + if (iter != dict.end()) return iter->second; + return MKLDNNDataType::undef; +} + +void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, + DataLayout out_layout, + const DenseTensor& in, + DenseTensor* out, + Place place, + bool always_copy = false); +void* GetDataFromTensor(const DenseTensor& tensor, MKLDNNDataType type); + +#endif + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/onednn/mkldnn_helper.h b/paddle/phi/kernels/funcs/onednn/mkldnn_helper.h new file mode 100644 index 00000000000..9a0aa8194c4 --- /dev/null +++ b/paddle/phi/kernels/funcs/onednn/mkldnn_helper.h @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "dnnl.hpp" // NOLINT +#include "glog/logging.h" + +#include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace funcs { + +using MKLDNNMemoryFormat = dnnl::memory::format_tag; +using MKLDNNDataType = dnnl::memory::data_type; + +template +void* to_void_cast(const Type* t) { + return static_cast(const_cast(t)); +} + +inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, + MKLDNNMemoryFormat data_format) { + if (dims_size == 1) { + return MKLDNNMemoryFormat::x; + } else if (dims_size == 2) { + return MKLDNNMemoryFormat::nc; + } else if (dims_size == 3) { + if (data_format == MKLDNNMemoryFormat::nchw) { + return MKLDNNMemoryFormat::ncw; + } else if (data_format == MKLDNNMemoryFormat::nhwc) { + return MKLDNNMemoryFormat::nwc; + } + } else if (dims_size == 4) { + if (data_format == MKLDNNMemoryFormat::goihw) { + return MKLDNNMemoryFormat::oihw; + } + } else if (dims_size == 5) { + if (data_format == MKLDNNMemoryFormat::goidhw) { + return MKLDNNMemoryFormat::oidhw; + } + if (data_format == MKLDNNMemoryFormat::nchw) { + return MKLDNNMemoryFormat::ncdhw; + } else if (data_format == MKLDNNMemoryFormat::nhwc) { + return MKLDNNMemoryFormat::ndhwc; + } + } else if (dims_size == 6) { + if (data_format == MKLDNNMemoryFormat::nchw) { + return MKLDNNMemoryFormat::abcdef; + } + } + return data_format; +} + +inline void MatchShapeToLayout(DenseTensor* tensor_in, + DataLayout from, + DataLayout to) { + auto print_dims = [](const std::vector& dims) { + std::ostringstream oss; + + if (!dims.empty()) { + oss << "["; + // Convert all but the last element to avoid a trailing "," + std::copy( + dims.begin(), dims.end() - 1, std::ostream_iterator(oss, ",")); + + // Now add the last element with no delimiter + oss << dims.back() << "]"; + } + + return oss.str(); + }; + + // In these data layouts, channel dimension is either on 2nd position: nChw or + // at last nhwC, so for dim==2 these layouts are the same and nothing should + // be done. Similarly for dim==1 when you have just one possible combination. + if (tensor_in->dims().size() < 3) { + VLOG(3) << "Keeping MKLDNN/NHWC/NDHWC output_shape" + << print_dims(phi::vectorize(tensor_in->dims())); + return; + } + + switch (from) { + case DataLayout::MKLDNN: + if ((to == DataLayout::NHWC) || (to == DataLayout::NDHWC)) { + auto dims = phi::vectorize(tensor_in->dims()); + std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); + tensor_in->Resize(phi::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: MKLDNN to: NHWC/NDHWC output_shape" + << print_dims(dims); + } + break; + case DataLayout::NHWC: + case DataLayout::NDHWC: + if (to == DataLayout::MKLDNN) { + auto dims = phi::vectorize(tensor_in->dims()); + std::rotate(dims.begin() + 1, dims.end() - 1, dims.end()); + tensor_in->Resize(phi::make_ddim(dims)); + VLOG(3) << "Rotating Shape from: NHWC/NDHWC to: MKLDNN output_shape" + << print_dims(dims); + } + break; + default: + break; + } +} + +struct mkldnn_dummy_primitive { + struct primitive_desc {}; + struct desc {}; +}; + +inline dnnl::memory::desc MKLDNNMemDesc(const std::vector& dims, + dnnl::memory::data_type data_type, + MKLDNNMemoryFormat format) { + return dnnl::memory::desc({dims}, data_type, format); +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h b/paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h index 96333132508..56f2da3b3bd 100644 --- a/paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h +++ b/paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h @@ -20,11 +20,12 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/data_layout_transform.h" -#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/onednn/mkldnn_helper.h" namespace phi { namespace funcs { @@ -33,10 +34,12 @@ using user_function = std::function(const float*)>; using memory = dnnl::memory; using Place = phi::Place; +using MKLDNNMemoryFormat = dnnl::memory::format_tag; + template + typename TBackward = mkldnn_dummy_primitive, + typename TBackward_params = mkldnn_dummy_primitive> class MKLDNNHandlerNoCachingT { public: MKLDNNHandlerNoCachingT(dnnl::engine engine, Place cpu_place) @@ -62,8 +65,8 @@ class MKLDNNHandlerNoCachingT { std::shared_ptr AcquireSrcMemory(const DenseTensor* input) { const T* input_data = input->data(); - return this->AcquireMemoryFromPrimitive( - fwd_pd_->src_desc(), paddle::platform::to_void_cast(input_data)); + return this->AcquireMemoryFromPrimitive(fwd_pd_->src_desc(), + to_void_cast(input_data)); } template @@ -81,16 +84,15 @@ class MKLDNNHandlerNoCachingT { template std::shared_ptr AcquireDstMemory(const DenseTensor* output) { const T_out* output_data = output->data(); - return this->AcquireMemoryFromPrimitive( - bwd_pd_->dst_desc(), - paddle::platform::to_void_cast(output_data)); + return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(), + to_void_cast(output_data)); } std::shared_ptr AcquireDiffDstMemory( const DenseTensor* diffdst) { const T* ptr = diffdst->data(); - return this->AcquireMemoryFromPrimitive( - bwd_pd_->diff_dst_desc(), paddle::platform::to_void_cast(ptr)); + return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_desc(), + to_void_cast(ptr)); } std::shared_ptr AcquireDiffSrcMemory(DenseTensor* diffsrc) { @@ -291,10 +293,110 @@ class ActivationMKLDNNHandler std::shared_ptr AcquireBackwardSrcMemory( const DenseTensor* input) { const T* input_data = input->data(); - return this->AcquireMemoryFromPrimitive( - this->bwd_pd_->src_desc(), - paddle::platform::to_void_cast(input_data)); + return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), + to_void_cast(input_data)); + } +}; + +class ReorderMKLDNNHandler { + public: + ReorderMKLDNNHandler(std::vector& dims, // NOLINT + DataType ptype, + dnnl::memory::data_type dtype, + dnnl::engine engine) + : dims_(dims), + ptype_(ptype), + ptype_dst_(ptype), + dtype_(dtype), + dtype_dst_(dtype), + engine_(engine) {} + + ReorderMKLDNNHandler(std::vector& dims, // NOLINT + DataType ptype, + dnnl::memory::data_type dtype, + DataType ptype_dst, + dnnl::memory::data_type dtype_dst, + dnnl::engine engine) + : dims_(dims), + ptype_(ptype), + ptype_dst_(ptype_dst), + dtype_(dtype), + dtype_dst_(dtype_dst), + engine_(engine) {} + + std::shared_ptr AcquireSrcMemory(const dnnl::memory::desc& md, + void* ptr) { + return std::make_shared(md, engine_, ptr); + } + + std::shared_ptr AcquireSrcMemory(const MKLDNNMemoryFormat& fmt, + void* ptr) { + auto md = dnnl::memory::desc(dims_, dtype_, fmt); + return std::make_shared(md, engine_, ptr); + } + + std::shared_ptr AcquireSubmemory( + const std::vector& dims, + const std::vector& offset, + const std::shared_ptr& mem_p) { + auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); + auto sub_mem_p = std::make_shared( + sub_md, engine_, mem_p->get_data_handle()); + return sub_mem_p; + } + + std::shared_ptr AcquireDstMemory(DenseTensor* output, + const MKLDNNMemoryFormat& fmt, + Place place) { + auto dst_md = MKLDNNMemDesc(dims_, dtype_dst_, fmt); + auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size()); + return std::make_shared(dst_md, engine_, dst_data); } + + std::shared_ptr AcquireDstMemory( + DenseTensor* output, const dnnl::memory::desc& src_md, Place place) { + if (ptype_dst_ == ptype_) { + auto dst_data = + output->mutable_data(place, ptype_dst_, src_md.get_size()); + return std::make_shared(src_md, engine_, dst_data); + } else { + auto dst_md = src_md; + dst_md.data.data_type = static_cast(dtype_dst_); + auto dst_data = + output->mutable_data(place, ptype_dst_, dst_md.get_size()); + return std::make_shared(dst_md, engine_, dst_data); + } + } + + std::shared_ptr AcquireDstMemory( + DenseTensor* output, + const std::vector& dims, + const MKLDNNMemoryFormat& fmt, + Place place) { + auto dst_md = MKLDNNMemDesc(dims, dtype_dst_, fmt); + auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size()); + return std::make_shared(dst_md, engine_, dst_data); + } + + std::shared_ptr AcquireReorder( + std::shared_ptr dst_memory_p, + std::shared_ptr src_memory_p) { + return std::make_shared(*(src_memory_p), *(dst_memory_p)); + } + + std::shared_ptr AcquireReorder( + std::shared_ptr dst_memory_p, + std::shared_ptr src_memory_p, + const dnnl::primitive_attr& attrs) { + return std::make_shared( + *(src_memory_p), *(dst_memory_p), attrs); + } + + private: + std::vector dims_; + DataType ptype_, ptype_dst_; + dnnl::memory::data_type dtype_, dtype_dst_; + dnnl::engine engine_; }; } // namespace funcs diff --git a/paddle/phi/kernels/transfer_layout_kernel.cc b/paddle/phi/kernels/transfer_layout_kernel.cc index f7ecf379fdf..2110a06f161 100644 --- a/paddle/phi/kernels/transfer_layout_kernel.cc +++ b/paddle/phi/kernels/transfer_layout_kernel.cc @@ -14,11 +14,17 @@ limitations under the License. */ #include "paddle/phi/kernels/transfer_layout_kernel.h" +#include +#include + #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" #include "paddle/phi/kernels/funcs/math_function.h" - +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/phi/kernels/funcs/onednn/mkldnn_helper.h" +#endif namespace phi { std::vector GetAxis(const DataLayout& from, const DataLayout& to) { @@ -46,10 +52,10 @@ void CastDataLayout(const Context& dev_ctx, } template -void TransferLayoutKernel(const Context& dev_ctx, - const DenseTensor& x, - DataLayout dst_layout, - DenseTensor* out) { +void TransferLayoutGeneral(const Context& dev_ctx, + const DenseTensor& x, + DataLayout dst_layout, + DenseTensor* out) { auto src_dim = x.dims(); auto axis = GetAxis(x.layout(), dst_layout); @@ -60,16 +66,110 @@ void TransferLayoutKernel(const Context& dev_ctx, dst_dim[i] = src_dim[axis[i]]; } - out->ResizeAndAllocate(phi::make_ddim(dst_dim)); + out->Resize(phi::make_ddim(dst_dim)); + dev_ctx.Alloc(out, x.dtype()); PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] { CastDataLayout(dev_ctx, x, axis, out); })); } +#ifdef PADDLE_WITH_MKLDNN +template +void TransferLayoutMKLDNN(const Context& dev_ctx, + const DenseTensor& x, + DataLayout src_layout, + DataLayout dst_layout, + DenseTensor* out) { + auto print_tensor_meta = [](const DenseTensor& x) { + std::ostringstream oss; + + oss << "["; + oss << "layout:" << x.layout() << " ,"; + oss << "dims:" << x.dims() << " ,"; + if (x.IsInitialized()) oss << "place:" << x.place(); + oss << "]"; + + return oss.str(); + }; + VLOG(10) << " x: " << print_tensor_meta(x); + VLOG(10) << " out: " << print_tensor_meta(*out) << " " << out; + + // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in + // data_transfer.cc + if (!x.IsInitialized() && src_layout == DataLayout::MKLDNN && + dst_layout == DataLayout::NHWC) { + VLOG(4) << src_layout << "->" << dst_layout << " " << x.layout(); + out->Resize(x.dims()); + out->set_layout(dst_layout); + funcs::MatchShapeToLayout(out, src_layout, dst_layout); + return; + } + + if (src_layout != DataLayout::MKLDNN && dst_layout == DataLayout::MKLDNN) { + // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel + // Just set layout/format. No real transform occur + auto out_format = funcs::MKLDNNFormatForSize( + x.dims().size(), funcs::ToMKLDNNFormat(src_layout)); + + out->ShareDataWith(x); + // For NHWC data we need reshape of tensors as MKL-DNN + // is expecting NHWC dims description order + if (src_layout == DataLayout::NHWC) { + VLOG(4) << "NHWC"; + funcs::MatchShapeToLayout(out, src_layout, dst_layout); + OneDNNContext::tls().set_cur_paddle_data_layout(src_layout); + } + + out->set_layout(DataLayout::MKLDNN); + out->set_format(out_format); + } else if (src_layout == DataLayout::MKLDNN && + dst_layout != DataLayout::MKLDNN) { + // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel + // Do transform via MKLDNN lib + funcs::innerTransDataLayoutFromMKLDNN( + src_layout, dst_layout, x, out, dev_ctx.GetPlace()); + } else if (src_layout == DataLayout::MKLDNN && + dst_layout == DataLayout::MKLDNN) { + PADDLE_ENFORCE_NE( + src_layout, + dst_layout, + errors::PreconditionNotMet( + "No layout transform needed between two MKLDNN OPKernels.")); + } else { + TransferLayoutGeneral(dev_ctx, x, dst_layout, out); + } +} +#endif + +template +void TransferLayoutKernel(const Context& dev_ctx, + const DenseTensor& x, + int src_layout, + int dst_layout, + DenseTensor* out) { + PADDLE_ENFORCE_NE(src_layout, + dst_layout, + errors::PreconditionNotMet( + "No layout transform needed between same layout.")); + VLOG(10) << "TransDataLayout from " << static_cast(src_layout) + << " -> " << static_cast(dst_layout); + +#ifdef PADDLE_WITH_MKLDNN + TransferLayoutMKLDNN(dev_ctx, + x, + static_cast(src_layout), + static_cast(dst_layout), + out); +#else + TransferLayoutGeneral( + dev_ctx, x, static_cast(dst_layout), out); +#endif +} + } // namespace phi -PD_REGISTER_GENERAL_KERNEL(phi_transfer_layout, +PD_REGISTER_GENERAL_KERNEL(transfer_layout, CPU, ALL_LAYOUT, phi::TransferLayoutKernel, diff --git a/paddle/phi/kernels/transfer_layout_kernel.h b/paddle/phi/kernels/transfer_layout_kernel.h index 3777daf07de..73e12927d7f 100644 --- a/paddle/phi/kernels/transfer_layout_kernel.h +++ b/paddle/phi/kernels/transfer_layout_kernel.h @@ -23,7 +23,8 @@ namespace phi { template void TransferLayoutKernel(const Context& dev_ctx, const DenseTensor& x, - DataLayout dst_layout, + int src_layout, + int dst_layout, DenseTensor* out); template @@ -32,7 +33,11 @@ DenseTensor TransferLayout(const Context& dev_ctx, DataLayout dst_layout) { phi::DenseTensor dense_out = phi::Empty(dev_ctx, {x.dtype(), x.dims(), dst_layout}); - TransferLayoutKernel(dev_ctx, x, dst_layout, &dense_out); + TransferLayoutKernel(dev_ctx, + x, + static_cast(x.layout()), + static_cast(dst_layout), + &dense_out); return dense_out; } diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index 152bc0dd0c0..d1c9d25483f 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -134,3 +134,8 @@ cc_test( test_memcpy_dev_api SRCS test_memcpy_dev_api.cc DEPS phi phi_api_utils) + +cc_test( + test_transfer_layout_dev_api + SRCS test_transfer_layout_dev_api.cc + DEPS phi phi_api_utils) diff --git a/paddle/phi/tests/kernels/test_transfer_layout_dev_api.cc b/paddle/phi/tests/kernels/test_transfer_layout_dev_api.cc new file mode 100644 index 00000000000..0c81ecada96 --- /dev/null +++ b/paddle/phi/tests/kernels/test_transfer_layout_dev_api.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/transfer_layout_kernel.h" + +namespace phi { +namespace tests { + +#ifdef PADDLE_WITH_MKLDNN +TEST(DEV_API, transfer_layout) { + // 1. create tensor + + const int n = 2; + const int c = 3; + const int h = 4; + const int w = 5; + + DenseTensor x; + MetaTensor meta_x(&x); + meta_x.set_dtype(DataType::FLOAT32); + meta_x.set_layout(DataLayout::MKLDNN); + meta_x.set_dims(make_ddim({n, c, h, w})); + + DenseTensor out; + + // 2. test API + auto& pool = phi::DeviceContextPool::Instance(); + auto place = phi::CPUPlace(); + auto* dev_ctx = static_cast(pool.GetByPlace(place)); + + MetaTensor meta_out(&out); + TransferLayoutInferMeta(x, + static_cast(x.layout()), + static_cast(DataLayout::NHWC), + &meta_out); + TransferLayoutKernel(*dev_ctx, + x, + static_cast(x.layout()), + static_cast(DataLayout::NHWC), + &out); + + // 3. check result + std::vector expect_shape = {12, 3}; + ASSERT_EQ(out.dims(), make_ddim({n, h, w, c})); + ASSERT_EQ(out.dims().size(), 4); + ASSERT_EQ(out.meta().dtype, DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, DataLayout::NHWC); +} + +#endif +} // namespace tests +} // namespace phi -- GitLab