From b232b5e9a1b51b1b6550aaf5b0b65ea5531677bc Mon Sep 17 00:00:00 2001 From: Piotr Paturej <48731682+piotrekobi@users.noreply.github.com> Date: Tue, 20 Sep 2022 14:14:33 +0200 Subject: [PATCH] [PHI] Migrate slice, slice_grad, split, pad and pad3d oneDNN kernels (#46101) * Convert split, pad and pad3d kernels * Convert slice+grad oneDNN fluid kernels to PHI * change out->mutable_data to dev_ctx.Alloc --- .../fluid/operators/mkldnn/pad3d_mkldnn_op.cc | 223 ----------------- .../fluid/operators/mkldnn/slice_mkldnn_op.cc | 236 ------------------ .../fluid/operators/mkldnn/split_mkldnn_op.cc | 140 ----------- paddle/phi/backends/onednn/onednn_helper.h | 12 +- paddle/phi/backends/onednn/onednn_reuse.h | 8 +- paddle/phi/kernels/onednn/pad3d_kernel.cc | 34 +++ paddle/phi/kernels/onednn/pad_kernel.cc | 37 +++ paddle/phi/kernels/onednn/pad_kernel_impl.h | 177 +++++++++++++ .../phi/kernels/onednn/slice_grad_kernel.cc | 86 +++++++ paddle/phi/kernels/onednn/slice_kernel.cc | 109 ++++++++ paddle/phi/kernels/onednn/split_kernel.cc | 90 +++++++ 11 files changed, 543 insertions(+), 609 deletions(-) delete mode 100644 paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc delete mode 100644 paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc delete mode 100644 paddle/fluid/operators/mkldnn/split_mkldnn_op.cc create mode 100644 paddle/phi/kernels/onednn/pad3d_kernel.cc create mode 100644 paddle/phi/kernels/onednn/pad_kernel.cc create mode 100644 paddle/phi/kernels/onednn/pad_kernel_impl.h create mode 100644 paddle/phi/kernels/onednn/slice_grad_kernel.cc create mode 100644 paddle/phi/kernels/onednn/slice_kernel.cc create mode 100644 paddle/phi/kernels/onednn/split_kernel.cc diff --git a/paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc deleted file mode 100644 index e7a528c452..0000000000 --- a/paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc +++ /dev/null @@ -1,223 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" -namespace paddle { -namespace operators { - -using framework::Tensor; - -/* -Pad3D is done by using up to 7 reorders. Following example is done -on 2D data for simplicity, but it is straightforward to extend it to 3D case. - -Let us consider following example: - - N C H W L R T B -X_dims = (1, 1, 3, 3), paddings = (1, 2, 3, 4) in order Left, Right, Top, Bottom - -We have to copy the X tensor into Out tensor, but except from that we have to -fill the rest of the memory with an additional padding. To avoid looping through -the whole Out memory two times, only these parts of Out memory that won't store -X's memory are filled with pad value. That behavior is achieved by using -oneDNN's submemory descriptors which allows us to set offsets for each dimension -and skip some parts of the memory. For 2D case up to 5 reorders will be used in -Pad3D kernel(if padding=0 reorder is skipped). In the following example i'th -number means, that this part of memory was filled by i'th reorder. 4'th reorder -is copying X memory into Out memory. i&j means that both i'th and j'th reorder -will set the padding at that location: - - INDEX - | 0 1 2 3 4 5 - |_______________________ - 0 |0&2 2 2 2 1&2 1&2 - 1 |0&2 2 2 2 1&2 1&2 -I 2 |0&2 2 2 2 1&2 1&2 -N 3 | 0 4 4 4 1 1 -D 4 | 0 4 4 4 1 1 -E 5 | 0 4 4 4 1 1 -X 6 |0&3 3 3 3 1&3 1&3 - 7 |0&3 3 3 3 1&3 1&3 - 8 |0&3 3 3 3 1&3 1&3 - 9 |0&3 3 3 3 1&3 1&3 - -Since oneDNN's reorder cannot set the pad value to the memory by itself, we have -to prefill Out's memory and use it as a temporary buffer, which later is copied -into the rest of Out's memory. At the end last reorder is done which copies X -memory into Out memory. - -*/ -template -class PadMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - auto* paddings_tensor = ctx.Input("Paddings"); - std::vector paddings(ctx.Attr>("paddings")); - if (paddings_tensor) { - std::copy(paddings_tensor->data(), - paddings_tensor->data() + paddings_tensor->numel(), - paddings.data()); - } - // pad2d has paddings in order top, bottom, left, right, so we need - // to swap some of them to unify paddings between pad2d and pad3d - if (ctx.Type() == "pad2d") { - std::swap(paddings[0], paddings[2]); - std::swap(paddings[1], paddings[3]); - } - - const std::string pad_attr_name = - ctx.Type() == "pad3d" ? "value" : "pad_value"; - T pad_value = static_cast(ctx.Attr(pad_attr_name)); - - std::vector x_tz = phi::vectorize(x->dims()); - // due to the need of supporting NDHWC, inferring out shape - // must be done inside the kernel - std::vector out_tz(x_tz); - - for (size_t i = 0; i < paddings.size() / 2; ++i) { - out_tz[out_tz.size() - 1 - i] += paddings[2 * i] + paddings[2 * i + 1]; - } - out->Resize(phi::make_ddim(out_tz)); - - auto paddle_dtype = framework::TransToProtoVarType(x->dtype()); - - platform::ReorderMKLDNNHandler reorder_handler( - x_tz, - paddle_dtype, - framework::ToMKLDNNDataType(paddle_dtype), - onednn_engine); - - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - x->mem_desc(), platform::to_void_cast(x->data())); - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - out, - out_tz, - platform::GetPlainMKLDNNFormat(out_tz.size()), - ctx.GetPlace()); - - // to avoid allocating new temporary memory, Out's memory is used as a tmp - // buffer for storing a contiguous memory consisting of pad_value, which - // later is used as a SRC for reorders that are filling Out with padding - T* out_ptr = out->data(); - std::fill(out_ptr, - out_ptr + CalculateNumOfPrefillElems(out_tz, paddings), - pad_value); - - // paddings are in order: left, right, top, bottom, front, back - for (size_t i = 0; i < paddings.size(); ++i) { - if (paddings[i] != 0) { - std::vector offsets(out_tz.size(), 0); - std::vector chunk_tz(out_tz.begin(), out_tz.end()); - - chunk_tz[out_tz.size() - 1 - i / 2] = paddings[i]; - if (i % 2 == 1) { - offsets[out_tz.size() - 1 - i / 2] = - paddings[i - 1] + x_tz[out_tz.size() - 1 - i / 2]; - } - - FillPartOfPadding(paddle_dtype, - onednn_engine, - out_ptr, - reorder_dst_memory_p, - chunk_tz, - offsets); - } - } - astream.wait(); - - std::vector offsets(out_tz.size(), 0); - for (size_t i = 0; i < paddings.size() / 2; ++i) { - offsets[out_tz.size() - 1 - i] = paddings[2 * i]; - } - - auto slice_mem_p = - reorder_handler.AcquireSubmemory(x_tz, offsets, reorder_dst_memory_p); - - auto reorder_p = - reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p); - reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p); - astream.wait(); - - out->set_mem_desc(reorder_dst_memory_p->get_desc()); - } - - int64_t CalculateNumOfPrefillElems(const std::vector& out_tz, - const std::vector& paddings) const { - int64_t max_elems = 0; - int64_t independent_dims = out_tz[0] * out_tz[1]; - - for (size_t i = 0; i < paddings.size() / 2; ++i) { - int64_t elems = std::max(paddings[2 * i], paddings[2 * i + 1]); - for (size_t j = 0; j < paddings.size() / 2; ++j) { - if (j != i) { - elems *= out_tz[out_tz.size() - 1 - j]; - } - } - - if (max_elems < elems) { - max_elems = elems; - } - } - return independent_dims * max_elems; - } - - void FillPartOfPadding(framework::proto::VarType::Type paddle_dtype, - const dnnl::engine& onednn_engine, - T* prefilled_mem_ptr, - const std::shared_ptr& out_mem_p, - const std::vector& chunk_tz, - const std::vector& offsets) const { - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - dnnl::memory::desc prefilled_mem_desc( - chunk_tz, - platform::MKLDNNGetDataType(), - platform::GetPlainMKLDNNFormat(chunk_tz.size())); - dnnl::memory prefilled_mem( - prefilled_mem_desc, onednn_engine, prefilled_mem_ptr); - - dnnl::memory::desc out_slice_md = - out_mem_p->get_desc().submemory_desc(chunk_tz, {offsets}); - dnnl::memory out_slice_mem( - out_slice_md, onednn_engine, out_mem_p->get_data_handle()); - - auto reorder_p = dnnl::reorder(prefilled_mem, out_slice_mem); - reorder_p.execute(astream, prefilled_mem, out_slice_mem); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(pad3d, - MKLDNN, - paddle::platform::CPUPlace, - ops::PadMKLDNNKernel); - -REGISTER_OP_KERNEL(pad2d, - MKLDNN, - paddle::platform::CPUPlace, - ops::PadMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc deleted file mode 100644 index a7c6bd2848..0000000000 --- a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; - -template -class SliceMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("Input"); - auto* out = ctx.Output("Out"); - - auto x_vec_dims = phi::vectorize(x->dims()); - - auto axes_int = ctx.Attr>("axes"); - auto starts_int = ctx.Attr>("starts"); - auto ends_int = ctx.Attr>("ends"); - - std::vector axes(ctx.Attr>("axes").begin(), - ctx.Attr>("axes").end()); - std::vector starts(ctx.Attr>("starts").begin(), - ctx.Attr>("starts").end()); - std::vector ends(ctx.Attr>("ends").begin(), - ctx.Attr>("ends").end()); - - auto starts_tensor_list = ctx.MultiInput("StartsTensorList"); - if (ctx.HasInput("StartsTensor")) { - starts = GetDataFromTensor(ctx.Input("StartsTensor")); - } else if (starts_tensor_list.size() > 0) { - starts = GetDataFromTensorList(starts_tensor_list); - } - - auto decrease_axis = ctx.Attr>("decrease_axis"); - - auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); - if (ctx.HasInput("EndsTensor")) { - ends = GetDataFromTensor(ctx.Input("EndsTensor")); - } else if (ends_tensor_list.size() > 0) { - ends = GetDataFromTensorList(ends_tensor_list); - } - - std::vector offsets(x_vec_dims.size(), 0); - std::vector slice_dims(x_vec_dims); - - for (size_t i = 0; i < axes.size(); ++i) { - starts[i] = starts[i] < 0 ? x_vec_dims[axes[i]] + starts[i] : starts[i]; - ends[i] = ends[i] < 0 ? x_vec_dims[axes[i]] + ends[i] - : std::min(ends[i], x_vec_dims[axes[i]]); - offsets[axes[i]] = starts[i]; - slice_dims[axes[i]] = - std::max(static_cast(0), ends[i] - starts[i]); - } - - out->Resize(phi::make_ddim(slice_dims)); - - // Note(0x45f): To support slice Tensors with shapes like [0, 0, 0]. - if (!x->initialized()) { - out->mutable_data(x->place(), x->dtype()); - out->set_layout(experimental::DataLayout::kMKLDNN); - return; - } - - dnnl::memory::data_type x_type = - framework::ToMKLDNNDataType(framework::TransToProtoVarType(x->dtype())); - - platform::ReorderMKLDNNHandler reorder_handler( - x_vec_dims, - framework::TransToProtoVarType(x->dtype()), - x_type, - onednn_engine); - - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - x->mem_desc(), platform::to_void_cast(x->data())); - auto slice_mem_p = reorder_handler.AcquireSubmemory( - slice_dims, offsets, reorder_src_memory_p); - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - out, - slice_dims, - platform::GetPlainMKLDNNFormat(x_vec_dims.size()), - ctx.GetPlace()); - - auto reorder_p = - reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); - - std::vector new_out_dims(slice_dims.size() - decrease_axis.size()); - - if (new_out_dims.size() == 0) { - new_out_dims.emplace_back(1); - } else { - for (const auto& axis : decrease_axis) { - slice_dims[axis] = 0; - } - - int i = 0; - for (const auto& slice_dim : slice_dims) { - if (slice_dim != 0) new_out_dims[i++] = slice_dim; - } - } - - astream.wait(); - out->Resize(phi::make_ddim(new_out_dims)); - out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(new_out_dims)); - } -}; -template -class SliceGradMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("Input")); - - auto dx_vec_dims = phi::vectorize(dx->dims()); - auto dout_vec_dims = phi::vectorize(dout->dims()); - - auto axes_int = ctx.Attr>("axes"); - auto starts_int = ctx.Attr>("starts"); - auto ends_int = ctx.Attr>("ends"); - - std::vector axes(ctx.Attr>("axes").begin(), - ctx.Attr>("axes").end()); - std::vector starts(ctx.Attr>("starts").begin(), - ctx.Attr>("starts").end()); - std::vector ends(ctx.Attr>("ends").begin(), - ctx.Attr>("ends").end()); - - auto starts_tensor_list = ctx.MultiInput("StartsTensorList"); - if (ctx.HasInput("StartsTensor")) { - starts = GetDataFromTensor(ctx.Input("StartsTensor")); - } else if (starts_tensor_list.size() > 0) { - starts = GetDataFromTensorList(starts_tensor_list); - } - - auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); - if (ctx.HasInput("EndsTensor")) { - ends = GetDataFromTensor(ctx.Input("EndsTensor")); - } else if (ends_tensor_list.size() > 0) { - ends = GetDataFromTensorList(ends_tensor_list); - } - - auto decrease_axis = ctx.Attr>("decrease_axis"); - - std::vector offsets(dx_vec_dims.size(), 0); - std::vector slice_dims(dx_vec_dims); - - for (size_t i = 0; i < axes.size(); ++i) { - starts[i] = starts[i] < 0 ? dx_vec_dims[axes[i]] + starts[i] : starts[i]; - ends[i] = ends[i] < 0 ? dx_vec_dims[axes[i]] + ends[i] - : std::min(ends[i], dx_vec_dims[axes[i]]); - offsets[axes[i]] = starts[i]; - slice_dims[axes[i]] = ends[i] - starts[i]; - } - - dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(dout->dtype())); - - platform::ReorderMKLDNNHandler reorder_handler( - slice_dims, - framework::TransToProtoVarType(dout->dtype()), - dout_type, - onednn_engine); - - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - dout->mem_desc().reshape(slice_dims), - platform::to_void_cast(dout->data())); - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - dx, - dx_vec_dims, - platform::GetPlainMKLDNNFormat(dx_vec_dims.size()), - ctx.GetPlace()); - memset(dx->data(), 0, reorder_dst_memory_p->get_desc().get_size()); - - auto slice_mem_p = reorder_handler.AcquireSubmemory( - slice_dims, offsets, reorder_dst_memory_p); - - auto reorder_p = - reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p); - astream.wait(); - - dx->set_mem_desc(reorder_dst_memory_p->get_desc()); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(slice, - MKLDNN, - paddle::platform::CPUPlace, - ops::SliceMKLDNNKernel, - ops::SliceMKLDNNKernel, - ops::SliceMKLDNNKernel, - ops::SliceMKLDNNKernel); - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(slice_grad, - MKLDNN, - paddle::platform::CPUPlace, - ops::SliceGradMKLDNNKernel, - ops::SliceGradMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc deleted file mode 100644 index f71931ad1e..0000000000 --- a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; - -static inline std::vector> CalculateOutsDims( - const framework::DDim& in_dims, - const size_t num, - const std::vector& sections, - const size_t axis, - const int outs_number) { - std::vector> outs_dims(outs_number, - phi::vectorize(in_dims)); - - if (num > 0) { - PADDLE_ENFORCE_EQ(in_dims[axis] % num, - 0, - platform::errors::InvalidArgument( - "The input's size along the split dimension " - "must be evenly divisible by Attr(num_or_sections). " - "But received Attr(num_or_sections) " - "= %d, input(X)'s shape = [%s], Attr(dim) = %d.", - num, - in_dims, - axis)); - - const size_t out_axis_dim = in_dims[axis] / num; - - for (auto& out_dim : outs_dims) out_dim[axis] = out_axis_dim; - } else { - for (size_t i = 0; i < outs_dims.size(); ++i) - outs_dims[i][axis] = sections[i]; - } - return outs_dims; -} - -template -class SplitMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - const auto* x = ctx.Input("X"); - auto outs = ctx.MultiOutput("Out"); - - int num = ctx.Attr("num"); - auto sections = ctx.Attr>("sections"); - int axis = ctx.Attr("axis"); - auto outs_number = outs.size(); - const auto x_dims = x->dims(); - - bool need_resize = false; - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - axis = GetDataFromTensor(axis_tensor)[0]; - need_resize = true; - } - - auto sections_tensor_list = ctx.MultiInput("SectionsTensorList"); - if (sections_tensor_list.size() > 0) { - sections = GetDataFromTensorList(sections_tensor_list); - need_resize = true; - } - - if (need_resize) { - const auto outs_dims = - CalculateOutsDims(x->dims(), num, sections, axis, outs_number); - for (size_t i = 0; i < outs.size(); ++i) { - outs[i]->Resize(phi::make_ddim(outs_dims[i])); - } - } - - auto x_vec_dims = phi::vectorize(x_dims); - - dnnl::memory::data_type x_type = - framework::ToMKLDNNDataType(framework::TransToProtoVarType(x->dtype())); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - std::vector offset(x_vec_dims.size(), 0); - - platform::ReorderMKLDNNHandler reorder_handler( - x_vec_dims, - framework::TransToProtoVarType(x->dtype()), - x_type, - onednn_engine); - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - x->mem_desc(), platform::to_void_cast(x->data())); - - for (size_t i = 0; i < outs_number; ++i) { - auto out_vec_dims = phi::vectorize(outs[i]->dims()); - auto slice_mem_p = reorder_handler.AcquireSubmemory( - out_vec_dims, offset, reorder_src_memory_p); - - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - outs[i], out_vec_dims, x->format(), ctx.GetPlace()); - auto reorder_p = - reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); - - reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); - - offset[axis] += num > 0 ? x->dims()[axis] / num : sections[i]; - - outs[i]->set_mem_desc(reorder_dst_memory_p->get_desc()); - } - astream.wait(); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(split, - MKLDNN, - paddle::platform::CPUPlace, - ops::SplitMKLDNNKernel, - ops::SplitMKLDNNKernel); diff --git a/paddle/phi/backends/onednn/onednn_helper.h b/paddle/phi/backends/onednn/onednn_helper.h index aeaecf7491..e91e02282c 100644 --- a/paddle/phi/backends/onednn/onednn_helper.h +++ b/paddle/phi/backends/onednn/onednn_helper.h @@ -96,29 +96,29 @@ inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) { } template -dnnl::memory::data_type oneDNNGetDataType() { +dnnl::memory::data_type OneDNNGetDataType() { return dnnl::memory::data_type::undef; } template <> -inline dnnl::memory::data_type oneDNNGetDataType() { +inline dnnl::memory::data_type OneDNNGetDataType() { return dnnl::memory::data_type::f32; } template <> -inline dnnl::memory::data_type oneDNNGetDataType() { +inline dnnl::memory::data_type OneDNNGetDataType() { return dnnl::memory::data_type::s32; } template <> -inline dnnl::memory::data_type oneDNNGetDataType() { +inline dnnl::memory::data_type OneDNNGetDataType() { return dnnl::memory::data_type::s8; } template <> -inline dnnl::memory::data_type oneDNNGetDataType() { +inline dnnl::memory::data_type OneDNNGetDataType() { return dnnl::memory::data_type::u8; } template <> -inline dnnl::memory::data_type oneDNNGetDataType() { +inline dnnl::memory::data_type OneDNNGetDataType() { return dnnl::memory::data_type::bf16; } diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 66376dd883..6b806748d0 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -834,7 +834,7 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT { src0_md = src0_md.reshape(dims0_ex); } const auto dst_md = - memory::desc(dst_tz, oneDNNGetDataType(), OneDNNMemoryFormat::any); + memory::desc(dst_tz, OneDNNGetDataType(), OneDNNMemoryFormat::any); auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops); @@ -905,7 +905,7 @@ class BroadcastDataOneDNNHandler : OneDNNHandlerNoCachingT(engine, cpu_place) { const auto src0_tz = vectorize(out->dims()); const auto src0_md = dnnl::memory::desc( - src0_tz, oneDNNGetDataType(), GetPlainOneDNNFormat(src0_tz.size())); + src0_tz, OneDNNGetDataType(), GetPlainOneDNNFormat(src0_tz.size())); const auto src1_md = x->mem_desc().reshape(extended_x_dims); dnnl::primitive_attr attributes; @@ -940,7 +940,7 @@ class ReductionOneDNNHandler const dnnl::primitive_attr& attrs = NULL) : OneDNNHandlerNoCachingT(engine, cpu_place) { const auto out_md = memory::desc( - out_tz, oneDNNGetDataType(), dnnl::memory::format_tag::any); + out_tz, OneDNNGetDataType(), dnnl::memory::format_tag::any); if (attrs) this->AcquireForwardPrimitiveDescriptor( @@ -1144,7 +1144,7 @@ class PoolingOneDNNHandler const auto dt = ToOneDNNDataType(in_x->dtype()); auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, OneDNNMemoryFormat::any); auto diff_src_md = dnnl::memory::desc( - diff_src_tz, oneDNNGetDataType(), OneDNNMemoryFormat::any); + diff_src_tz, OneDNNGetDataType(), OneDNNMemoryFormat::any); auto onednn_paddings = ToOneDNNPadding(copied_paddings); diff --git a/paddle/phi/kernels/onednn/pad3d_kernel.cc b/paddle/phi/kernels/onednn/pad3d_kernel.cc new file mode 100644 index 0000000000..2d34e11afc --- /dev/null +++ b/paddle/phi/kernels/onednn/pad3d_kernel.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pad3d_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/onednn/pad_kernel_impl.h" + +namespace phi { + +template +void Pad3dKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* out) { + PadOpKernel(dev_ctx, x, paddings.GetData(), pad_value, out); +} +} // namespace phi + +PD_REGISTER_KERNEL(pad3d, OneDNN, ALL_LAYOUT, phi::Pad3dKernel, float) {} diff --git a/paddle/phi/kernels/onednn/pad_kernel.cc b/paddle/phi/kernels/onednn/pad_kernel.cc new file mode 100644 index 0000000000..4177f000db --- /dev/null +++ b/paddle/phi/kernels/onednn/pad_kernel.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/onednn/pad_kernel_impl.h" + +namespace phi { + +template +void PadKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& paddings, + const Scalar& pad_value, + DenseTensor* out) { + std::vector copied_paddings(paddings.begin(), paddings.end()); + + std::swap(copied_paddings[0], copied_paddings[2]); + std::swap(copied_paddings[1], copied_paddings[3]); + PadOpKernel( + dev_ctx, x, copied_paddings, pad_value.to(), out); +} +} // namespace phi + +PD_REGISTER_KERNEL(pad, OneDNN, ALL_LAYOUT, phi::PadKernel, float) {} diff --git a/paddle/phi/kernels/onednn/pad_kernel_impl.h b/paddle/phi/kernels/onednn/pad_kernel_impl.h new file mode 100644 index 0000000000..eabe18855b --- /dev/null +++ b/paddle/phi/kernels/onednn/pad_kernel_impl.h @@ -0,0 +1,177 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/onednn/onednn_reuse.h" + +namespace phi { + +/* +Pad3D is done by using up to 7 reorders. Following example is done +on 2D data for simplicity, but it is straightforward to extend it to 3D case. + +Let us consider following example: + + N C H W L R T B +X_dims = (1, 1, 3, 3), paddings = (1, 2, 3, 4) in order Left, Right, Top, Bottom + +We have to copy the X tensor into Out tensor, but except from that we have to +fill the rest of the memory with an additional padding. To avoid looping through +the whole Out memory two times, only these parts of Out memory that won't store +X's memory are filled with pad value. That behavior is achieved by using +oneDNN's submemory descriptors which allows us to set offsets for each dimension +and skip some parts of the memory. For 2D case up to 5 reorders will be used in +Pad3D kernel(if padding=0 reorder is skipped). In the following example i'th +number means, that this part of memory was filled by i'th reorder. 4'th reorder +is copying X memory into Out memory. i&j means that both i'th and j'th reorder +will set the padding at that location: + + INDEX + | 0 1 2 3 4 5 + |_______________________ + 0 |0&2 2 2 2 1&2 1&2 + 1 |0&2 2 2 2 1&2 1&2 +I 2 |0&2 2 2 2 1&2 1&2 +N 3 | 0 4 4 4 1 1 +D 4 | 0 4 4 4 1 1 +E 5 | 0 4 4 4 1 1 +X 6 |0&3 3 3 3 1&3 1&3 + 7 |0&3 3 3 3 1&3 1&3 + 8 |0&3 3 3 3 1&3 1&3 + 9 |0&3 3 3 3 1&3 1&3 + +Since oneDNN's reorder cannot set the pad value to the memory by itself, we have +to prefill Out's memory and use it as a temporary buffer, which later is copied +into the rest of Out's memory. At the end last reorder is done which copies X +memory into Out memory. + +*/ + +inline int64_t CalculateNumOfPrefillElems( + const std::vector& out_tz, const std::vector& paddings) { + int64_t max_elems = 0; + int64_t independent_dims = out_tz[0] * out_tz[1]; + + for (size_t i = 0; i < paddings.size() / 2; ++i) { + int64_t elems = std::max(paddings[2 * i], paddings[2 * i + 1]); + for (size_t j = 0; j < paddings.size() / 2; ++j) { + if (j != i) { + elems *= out_tz[out_tz.size() - 1 - j]; + } + } + + if (max_elems < elems) { + max_elems = elems; + } + } + return independent_dims * max_elems; +} + +template +void FillPartOfPadding(const dnnl::engine& onednn_engine, + T* prefilled_mem_ptr, + const std::shared_ptr& out_mem_p, + const std::vector& chunk_tz, + const std::vector& offsets) { + auto& astream = OneDNNContext::tls().get_stream(); + + dnnl::memory::desc prefilled_mem_desc( + chunk_tz, + funcs::OneDNNGetDataType(), + funcs::GetPlainOneDNNFormat(chunk_tz.size())); + dnnl::memory prefilled_mem( + prefilled_mem_desc, onednn_engine, prefilled_mem_ptr); + + dnnl::memory::desc out_slice_md = + out_mem_p->get_desc().submemory_desc(chunk_tz, {offsets}); + dnnl::memory out_slice_mem( + out_slice_md, onednn_engine, out_mem_p->get_data_handle()); + + auto reorder_p = dnnl::reorder(prefilled_mem, out_slice_mem); + reorder_p.execute(astream, prefilled_mem, out_slice_mem); +} + +template +void PadOpKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& paddings, + float pad_value, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + auto& astream = OneDNNContext::tls().get_stream(); + + std::vector x_tz = vectorize(x.dims()); + // due to the need of supporting NDHWC, inferring out shape + // must be done inside the kernel + std::vector out_tz(x_tz); + + for (size_t i = 0; i < paddings.size() / 2; ++i) { + out_tz[out_tz.size() - 1 - i] += paddings[2 * i] + paddings[2 * i + 1]; + } + out->Resize(make_ddim(out_tz)); + + funcs::ReorderOneDNNHandler reorder_handler( + x_tz, x.dtype(), funcs::ToOneDNNDataType(x.dtype()), onednn_engine); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x.mem_desc(), funcs::to_void_cast(x.data())); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + out, + out_tz, + funcs::GetPlainOneDNNFormat(out_tz.size()), + dev_ctx.GetPlace()); + + // to avoid allocating new temporary memory, Out's memory is used as a tmp + // buffer for storing a contiguous memory consisting of pad_value, which + // later is used as a SRC for reorders that are filling Out with padding + T* out_ptr = out->data(); + std::fill(out_ptr, + out_ptr + CalculateNumOfPrefillElems(out_tz, paddings), + pad_value); + + // paddings are in order: left, right, top, bottom, front, back + for (size_t i = 0; i < paddings.size(); ++i) { + if (paddings[i] != 0) { + std::vector offsets(out_tz.size(), 0); + std::vector chunk_tz(out_tz.begin(), out_tz.end()); + + chunk_tz[out_tz.size() - 1 - i / 2] = paddings[i]; + if (i % 2 == 1) { + offsets[out_tz.size() - 1 - i / 2] = + paddings[i - 1] + x_tz[out_tz.size() - 1 - i / 2]; + } + + FillPartOfPadding( + onednn_engine, out_ptr, reorder_dst_memory_p, chunk_tz, offsets); + } + } + astream.wait(); + + std::vector offsets(out_tz.size(), 0); + for (size_t i = 0; i < paddings.size() / 2; ++i) { + offsets[out_tz.size() - 1 - i] = paddings[2 * i]; + } + + auto slice_mem_p = + reorder_handler.AcquireSubmemory(x_tz, offsets, reorder_dst_memory_p); + + auto reorder_p = + reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p); + reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p); + astream.wait(); + + out->set_mem_desc(reorder_dst_memory_p->get_desc()); +} +} // namespace phi diff --git a/paddle/phi/kernels/onednn/slice_grad_kernel.cc b/paddle/phi/kernels/onednn/slice_grad_kernel.cc new file mode 100644 index 0000000000..c38a2237e5 --- /dev/null +++ b/paddle/phi/kernels/onednn/slice_grad_kernel.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/slice_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SliceGradRawKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* input_grad) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto dx_dims = vectorize(input_grad->dims()); + + auto starts_vec = starts.GetData(); + auto ends_vec = ends.GetData(); + + std::vector offsets(dx_dims.size(), 0); + std::vector slice_dims(dx_dims); + + for (size_t i = 0; i < axes.size(); ++i) { + starts_vec[i] = + starts_vec[i] < 0 ? dx_dims[axes[i]] + starts_vec[i] : starts_vec[i]; + ends_vec[i] = ends_vec[i] < 0 ? dx_dims[axes[i]] + ends_vec[i] + : std::min(ends_vec[i], dx_dims[axes[i]]); + offsets[axes[i]] = starts_vec[i]; + slice_dims[axes[i]] = ends_vec[i] - starts_vec[i]; + } + + dnnl::memory::data_type out_grad_type = + funcs::ToOneDNNDataType(out_grad.dtype()); + + funcs::ReorderOneDNNHandler reorder_handler( + slice_dims, out_grad.dtype(), out_grad_type, onednn_engine); + + auto reorder_src_memory_p = + reorder_handler.AcquireSrcMemory(out_grad.mem_desc().reshape(slice_dims), + funcs::to_void_cast(out_grad.data())); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + input_grad, + dx_dims, + funcs::GetPlainOneDNNFormat(dx_dims.size()), + dev_ctx.GetPlace()); + memset(input_grad->data(), 0, reorder_dst_memory_p->get_desc().get_size()); + + auto slice_mem_p = reorder_handler.AcquireSubmemory( + slice_dims, offsets, reorder_dst_memory_p); + + auto reorder_p = + reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p); + auto& astream = OneDNNContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p); + astream.wait(); + + input_grad->set_mem_desc(reorder_dst_memory_p->get_desc()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(slice_grad, + OneDNN, + ALL_LAYOUT, + phi::SliceGradRawKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/slice_kernel.cc b/paddle/phi/kernels/onednn/slice_kernel.cc new file mode 100644 index 0000000000..3f74a2fe0b --- /dev/null +++ b/paddle/phi/kernels/onednn/slice_kernel.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/slice_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SliceRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto x_vec_dims = vectorize(x.dims()); + + auto starts_vec = starts.GetData(); + auto ends_vec = ends.GetData(); + + std::vector offsets(x_vec_dims.size(), 0); + std::vector slice_dims(x_vec_dims); + + for (size_t i = 0; i < axes.size(); ++i) { + starts_vec[i] = + starts_vec[i] < 0 ? x_vec_dims[axes[i]] + starts_vec[i] : starts_vec[i]; + ends_vec[i] = ends_vec[i] < 0 ? x_vec_dims[axes[i]] + ends_vec[i] + : std::min(ends_vec[i], x_vec_dims[axes[i]]); + offsets[axes[i]] = starts_vec[i]; + slice_dims[axes[i]] = + std::max(static_cast(0), ends_vec[i] - starts_vec[i]); + } + + out->Resize(make_ddim(slice_dims)); + + // Note(0x45f): To support slice Tensors with shapes like [0, 0, 0]. + if (!x.initialized()) { + dev_ctx.Alloc(out, x.dtype()); + out->set_layout(DataLayout::ONEDNN); + return; + } + + dnnl::memory::data_type x_type = funcs::ToOneDNNDataType(x.dtype()); + + funcs::ReorderOneDNNHandler reorder_handler( + x_vec_dims, x.dtype(), x_type, onednn_engine); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x.mem_desc(), funcs::to_void_cast(x.data())); + auto slice_mem_p = reorder_handler.AcquireSubmemory( + slice_dims, offsets, reorder_src_memory_p); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + out, + slice_dims, + funcs::GetPlainOneDNNFormat(x_vec_dims.size()), + dev_ctx.GetPlace()); + + auto reorder_p = + reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); + auto& astream = OneDNNContext::tls().get_stream(); + reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); + + std::vector new_out_dims(slice_dims.size() - decrease_axis.size()); + + if (new_out_dims.size() == 0) { + new_out_dims.emplace_back(1); + } else { + for (const auto& axis : decrease_axis) { + slice_dims[axis] = 0; + } + + int i = 0; + for (const auto& slice_dim : slice_dims) { + if (slice_dim != 0) new_out_dims[i++] = slice_dim; + } + } + + astream.wait(); + out->Resize(make_ddim(new_out_dims)); + out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(new_out_dims)); +} + +} // namespace phi + +PD_REGISTER_KERNEL(slice, + OneDNN, + ALL_LAYOUT, + phi::SliceRawKernel, + float, + int8_t, + uint8_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/split_kernel.cc b/paddle/phi/kernels/onednn/split_kernel.cc new file mode 100644 index 0000000000..1d0544758f --- /dev/null +++ b/paddle/phi/kernels/onednn/split_kernel.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/split_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& sections, + const Scalar& split_axis, + std::vector out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + int axis = split_axis.to(); + + auto outs_number = out.size(); + const auto x_dims = x.dims(); + auto x_vec_dims = vectorize(x_dims); + + dnnl::memory::data_type x_type = funcs::ToOneDNNDataType(x.dtype()); + + auto& astream = OneDNNContext::tls().get_stream(); + + std::vector offset(x_vec_dims.size(), 0); + funcs::ReorderOneDNNHandler reorder_handler( + x_vec_dims, x.dtype(), x_type, onednn_engine); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x.mem_desc(), funcs::to_void_cast(x.data())); + + for (size_t i = 0; i < outs_number; ++i) { + auto out_vec_dims = vectorize(out[i]->dims()); + auto slice_mem_p = reorder_handler.AcquireSubmemory( + out_vec_dims, offset, reorder_src_memory_p); + + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + out[i], out_vec_dims, x.format(), dev_ctx.GetPlace()); + auto reorder_p = + reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); + + reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); + + offset[axis] += sections.GetData()[i]; + out[i]->set_mem_desc(reorder_dst_memory_p->get_desc()); + } + astream.wait(); +} + +template +void SplitWithNumKernel(const Context& dev_ctx, + const DenseTensor& x, + int num, + const Scalar& axis_scalar, + std::vector outs) { + int axis_value = axis_scalar.to(); + auto input_axis_dim = x.dims().at(axis_value); + std::vector sections_vec; + for (int i = 0; i < num; ++i) { + sections_vec.push_back(input_axis_dim / num); + } + IntArray sections(sections_vec); + SplitKernel(dev_ctx, x, sections, axis_scalar, outs); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + split, OneDNN, ALL_LAYOUT, phi::SplitKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(split_with_num, + OneDNN, + ALL_LAYOUT, + phi::SplitWithNumKernel, + float, + phi::dtype::bfloat16) {} -- GitLab