From fdd0d6d03313ae1baa75aefa625fd29634a17a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Mon, 10 Oct 2022 10:09:45 +0200 Subject: [PATCH] =?UTF-8?q?[cherry-pick]=20[PHI]=20Migrate=20concat+grad,?= =?UTF-8?q?=20expand+grad,=20fill=5Fconstant=20=E2=80=A6=20oneDNN=20kernel?= =?UTF-8?q?s=20(#45863)=20(#46727)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [PHI] Migrate concat+grad, expand+grad, fill_constant, nearest_interp and bilinear_interp oneDNN kernels (#45863) * Migrate concat+grad, expand+grad, fill_constant, nearest_interp_v2 and bilinear_interp_v2 oneDNN kernels to PHI * Remove old namespace variable * Fix invalid out dims error * Add mutable_data method to concat output * Add check for -1 dim before computing out_dims * Capitalize oneDNNGetDataType function name * Change fill_constant kernel to correct PHI kernel * Attempt to fix dims error * Fix fill_constant (full) kernel * update dependencies Co-authored-by: Piotr Paturej <48731682+piotrekobi@users.noreply.github.com> --- .../operators/mkldnn/concat_mkldnn_op.cc | 264 ------------------ .../operators/mkldnn/expand_v2_mkldnn_op.cc | 177 ------------ .../mkldnn/fill_constant_mkldnn_op.cc | 140 ---------- .../operators/mkldnn/interpolate_mkldnn_op.cc | 12 - .../phi/kernels/onednn/concat_grad_kernel.cc | 84 ++++++ paddle/phi/kernels/onednn/concat_kernel.cc | 164 +++++++++++ .../phi/kernels/onednn/expand_grad_kernel.cc | 92 ++++++ paddle/phi/kernels/onednn/expand_kernel.cc | 83 ++++++ paddle/phi/kernels/onednn/full_kernel.cc | 88 ++++++ .../phi/kernels/onednn/interpolate_kernel.cc | 240 ++++++++++++++++ 10 files changed, 751 insertions(+), 593 deletions(-) delete mode 100644 paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc delete mode 100644 paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc delete mode 100644 paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc create mode 100644 paddle/phi/kernels/onednn/concat_grad_kernel.cc create mode 100644 paddle/phi/kernels/onednn/concat_kernel.cc create mode 100644 paddle/phi/kernels/onednn/expand_grad_kernel.cc create mode 100644 paddle/phi/kernels/onednn/expand_kernel.cc create mode 100644 paddle/phi/kernels/onednn/full_kernel.cc create mode 100644 paddle/phi/kernels/onednn/interpolate_kernel.cc diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc deleted file mode 100644 index b16576505df..00000000000 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ /dev/null @@ -1,264 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/operators/concat_op.h" -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/mkldnn_helper.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using dnnl::concat; -using dnnl::memory; -using dnnl::primitive; -using dnnl::stream; -using framework::DataLayout; -using framework::LoDTensor; -using framework::Tensor; -using platform::to_void_cast; - -template -class ConcatMKLDNNHandler - : public platform::MKLDNNHandlerNoCachingT { - public: - ConcatMKLDNNHandler(const framework::ExecutionContext& ctx, - const dnnl::engine mkldnn_engine, - const std::vector& inputs, - Tensor* output) - : platform::MKLDNNHandlerNoCachingT(mkldnn_engine, - ctx.GetPlace()) { - int concat_axis = ctx.Attr("axis"); - const int rank = inputs[0]->dims().size(); - PADDLE_ENFORCE_EQ( - concat_axis >= -rank && concat_axis < rank, - true, - platform::errors::InvalidArgument( - "The axis is expected to be in range of [%d, %d), but got %d", - -rank, - rank, - concat_axis)); - - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - concat_axis = GetDataFromTensor(axis_tensor)[0]; - auto out_dims = inputs[0]->dims(); - for (size_t i = 1; i < inputs.size(); ++i) { - out_dims[concat_axis] += inputs[i]->dims()[concat_axis]; - } - output->Resize(out_dims); - } - - if (concat_axis < 0) { - concat_axis = concat_axis + rank; - } - - memory::data_type dt = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(inputs[0]->dtype())); - std::vector srcs_md; - srcs_md.reserve(inputs.size()); - - // Create memory descriptors for each of inputs - for (size_t i = 0; i < inputs.size(); ++i) { - srcs_md.push_back(inputs[i]->mem_desc()); - } - - auto dst_dims = phi::vectorize(output->dims()); - - dnnl::memory::desc dst_md; - - // if concat is being used as a stack op(all source memories dims on - // concat_axis are equal to 1), then it may choose a non-optimal memory - // format tag for destination, because concat primitive is chosing it based - // on source memory descriptors and f.e.200x1x10 can be described as both - // abc and bac and both would be using exact same physical layout, but in - // that scenario bac will be chosen for destination no matter which - // formats are being set in inputs. In that scenario we are enforcing using - // a dense format, because it is the most common one and should be the best - // in terms of the performance - const auto src0_tz = srcs_md[0].dims(); - if (std::find(src0_tz.begin(), src0_tz.end(), 1) != src0_tz.end()) { - dst_md = memory::desc( - dst_dims, dt, platform::GetPlainMKLDNNFormat(dst_dims.size())); - } else { - dst_md = memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any); - } - - this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md); - } - - // (jczaja) concat oneDNN prim is not having .desc attribute so - // we cannot use base AcquireForwardPrimitiveDescriptor - void AcquireForwardPrimitiveDescriptor( - const memory::desc& dst_md, - const int concat_axis, - const std::vector& srcs_md) { - this->fwd_pd_.reset(new dnnl::concat::primitive_desc( - dst_md, concat_axis, srcs_md, this->engine_)); - } - - std::shared_ptr AcquireSrcMemory(const Tensor& input, int i) { - const T* input_data = input.data(); - return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i), - to_void_cast(input_data)); - } -}; - -static void EnforceLayouts(const std::vector inputs) { - for (auto* input : inputs) { - PADDLE_ENFORCE_EQ( - input->layout(), - DataLayout::kMKLDNN, - platform::errors::InvalidArgument("Wrong layout set for Input tensor")); - } -} - -// From a multi-input, gather only nonempty inputs -static const std::vector ReduceMultiInput( - const std::vector& inputs) { - std::vector reduced(inputs.size()); - auto end_it = std::copy_if( - inputs.begin(), inputs.end(), reduced.begin(), [](const Tensor* t) { - return t->numel() > 0; - }); - reduced.resize(std::distance(reduced.begin(), end_it)); - return reduced; -} - -template -class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - // If any of the multiple inputs of concat has an input size of 0, the - // actual size of the multi_input will change - auto multi_input = ReduceMultiInput(ctx.MultiInput("X")); - EnforceLayouts(multi_input); - Tensor* output = ctx.Output("Out"); - - ConcatMKLDNNHandler handler(ctx, mkldnn_engine, multi_input, output); - - std::vector> srcs; - srcs.reserve(multi_input.size()); - - auto dst_mem = handler.AcquireDstMemory(output); - auto concat_p = handler.AcquireForwardPrimitive(); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - std::unordered_map args; - for (size_t i = 0; i < multi_input.size(); ++i) { - srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i)); - args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))}); - } - args.insert({DNNL_ARG_DST, *dst_mem}); - - concat_p->execute(astream, args); - astream.wait(); - - output->set_mem_desc(dst_mem->get_desc()); - } -}; - -template -class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - auto out_var_names = ctx.OutputNames(framework::GradVarName("X")); - - const auto x = ctx.MultiInput("X"); - const auto* dout = ctx.Input(framework::GradVarName("Out")); - auto dx = ctx.MultiOutput(framework::GradVarName("X")); - - for (size_t i = 0; i < dx.size(); ++i) { - if (dx[i] != nullptr) { - dx[i]->set_lod(x[i]->lod()); - } - } - - int axis = ctx.Attr("axis"); - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - axis = GetDataFromTensor(axis_tensor)[0]; - } - - auto dout_vec_dims = phi::vectorize(dout->dims()); - - axis = ComputeAxis(axis, dout_vec_dims.size()); - - std::vector offset(dout_vec_dims.size(), 0); - - dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(dout->dtype())); - platform::ReorderMKLDNNHandler reorder_handler( - dout_vec_dims, - framework::TransToProtoVarType(dout->dtype()), - dout_type, - onednn_engine); - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - dout->mem_desc(), platform::to_void_cast(dout->data())); - - for (size_t i = 0; i < dx.size(); ++i) { - if (out_var_names[i] != framework::kEmptyVarName && - dx[i]->numel() != 0UL) { - auto dx_vec_dims = phi::vectorize(dx[i]->dims()); - auto slice_mem_p = reorder_handler.AcquireSubmemory( - dx_vec_dims, offset, reorder_src_memory_p); - - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - dx[i], - dx_vec_dims, - platform::GetPlainMKLDNNFormat(dx_vec_dims.size()), - ctx.GetPlace()); - auto reorder_p = - reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); - - reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); - - offset[axis] += dx[i]->dims()[axis]; - - dx[i]->set_mem_desc(reorder_dst_memory_p->get_desc()); - } - } - astream.wait(); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(concat, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::ConcatMKLDNNOpKernel, - ops::ConcatMKLDNNOpKernel, - ops::ConcatMKLDNNOpKernel, - ops::ConcatMKLDNNOpKernel); - -REGISTER_OP_KERNEL(concat_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::ConcatGradMKLDNNOpKernel, - ops::ConcatGradMKLDNNOpKernel); diff --git a/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc deleted file mode 100644 index d477fa0b2bf..00000000000 --- a/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/operators/expand_v2_op.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace { - -using paddle::framework::ExecutionContext; -using paddle::framework::GradVarName; -using paddle::framework::Tensor; -using paddle::platform::MKLDNNDeviceContext; -using phi::vectorize; - -template -class ExpandMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const ExecutionContext& ctx) const { - const auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - const auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - auto x_vec_dims = vectorize(x->dims()); - - auto out_new_dims = paddle::operators::get_expand_shape(ctx); - for (size_t i = 0; i < out_new_dims.size(); ++i) { - out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i]; - } - - if (x_vec_dims.size() != out_new_dims.size()) { - x_vec_dims = GetExtendedXDims(x_vec_dims, out_new_dims.size()); - } - - out->Resize(phi::make_ddim(out_new_dims)); - paddle::platform::BroadcastDataMKLDNNHandler handler( - dnnl::algorithm::binary_add, - onednn_engine, - ctx.GetPlace(), - x, - out, - 0.0f, - 1.0f, - x_vec_dims); - - auto src_memory_p = handler.AcquireSrcMemory(x); - auto dst_memory_p = handler.AcquireZeroedDstMemory(out); - auto binary_p = handler.AcquireForwardPrimitive(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *dst_memory_p}, - {DNNL_ARG_SRC_1, *src_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - binary_p->execute(astream, args); - astream.wait(); - - out->set_mem_desc(dst_memory_p->get_desc()); - } - - private: - std::vector GetExtendedXDims(const std::vector& x_vec_dims, - int new_size) const { - std::vector extended_x_dims(new_size, 1); - std::copy(x_vec_dims.begin(), - x_vec_dims.end(), - extended_x_dims.begin() + new_size - x_vec_dims.size()); - - return extended_x_dims; - } -}; - -template -class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const ExecutionContext& ctx) const { - const auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* dout = ctx.Input(GradVarName("Out")); - auto* dx = ctx.Output(GradVarName("X")); - - auto dx_vec_dims = vectorize(dx->dims()); - auto dout_vec_dims = vectorize(dout->dims()); - - if (dx_vec_dims.size() != dout_vec_dims.size()) { - dx_vec_dims.insert( - dx_vec_dims.begin(), dout_vec_dims.size() - dx_vec_dims.size(), 1); - } - - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - if (dout_vec_dims == dx_vec_dims) { - dnnl::memory::data_type dout_type = paddle::framework::ToMKLDNNDataType( - paddle::framework::TransToProtoVarType(dout->dtype())); - paddle::platform::ReorderMKLDNNHandler reorder_handler( - dout_vec_dims, - paddle::framework::TransToProtoVarType(dout->dtype()), - dout_type, - onednn_engine); - - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - dout->mem_desc(), paddle::platform::to_void_cast(dout->data())); - - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - dx, - paddle::platform::GetPlainMKLDNNFormat(dx_vec_dims.size()), - ctx.GetPlace()); - - auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, - reorder_dst_memory_p); - - reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); - astream.wait(); - - dx->set_mem_desc(reorder_dst_memory_p->get_desc()); - } else { - paddle::platform::ReductionMKLDNNHandler handler( - dnnl::algorithm::reduction_sum, - 0.0f, - 0.0f, - onednn_engine, - ctx.GetPlace(), - dout, - dx, - dx_vec_dims); - - auto src_memory_p = handler.AcquireSrcMemory(dout); - auto dst_memory_p = handler.AcquireDstMemory(dx); - - std::unordered_map reduction_args = { - {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; - - auto reduction_p = handler.AcquireForwardPrimitive(); - - reduction_p->execute(astream, reduction_args); - astream.wait(); - dx->set_layout(paddle::framework::DataLayout::kMKLDNN); - dx->set_mem_desc( - dst_memory_p->get_desc().reshape(vectorize(dx->dims()))); - } - } -}; -} // anonymous namespace - -REGISTER_OP_KERNEL(expand_v2, - MKLDNN, - paddle::platform::CPUPlace, - ExpandMKLDNNKernel, - ExpandMKLDNNKernel); - -REGISTER_OP_KERNEL(expand_v2_grad, - MKLDNN, - paddle::platform::CPUPlace, - ExpandGradMKLDNNKernel, - ExpandGradMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc deleted file mode 100644 index 615f43bb32c..00000000000 --- a/paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -class FillConstantMKLDNNHandler - : public platform::MKLDNNHandlerNoCachingT { - public: - FillConstantMKLDNNHandler(Tensor* out, - dnnl::engine engine, - platform::Place cpu_place) - : platform::MKLDNNHandlerNoCachingT(engine, cpu_place) { - const auto src0_md = - dnnl::memory::desc({out->numel(), sizeof(T)}, - platform::MKLDNNGetDataType(), - dnnl::memory::format_tag::ab); - - dnnl::primitive_attr attrs; - attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f}); - - this->AcquireForwardPrimitiveDescriptor( - attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md); - } - - static const dnnl::memory::desc src1_md; -}; - -template -const dnnl::memory::desc FillConstantMKLDNNHandler::src1_md( - {1, sizeof(T)}, - platform::MKLDNNGetDataType(), - dnnl::memory::format_tag::ab); - -template -class FillConstantMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& dnnl_engine = dev_ctx.GetEngine(); - - auto* out = ctx.Output("Out"); - T fill_value = CalculateFillValue(ctx); - - auto shape = GetShape(ctx); - out->Resize(shape); - - FillConstantMKLDNNHandler handler(out, dnnl_engine, ctx.GetPlace()); - - dnnl::memory constant_value_memory = - dnnl::memory(FillConstantMKLDNNHandler::src1_md, - dnnl_engine, - reinterpret_cast(&fill_value)); - - auto src0_memory_p = handler.AcquireDstMemory(out); - auto fill_constant_p = handler.AcquireForwardPrimitive(); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - fill_constant_p->execute(astream, - {{DNNL_ARG_SRC_0, *src0_memory_p}, - {DNNL_ARG_SRC_1, constant_value_memory}, - {DNNL_ARG_DST, *src0_memory_p}}); - astream.wait(); - - // src0_memory_p's md was just to allow the usage of a binary - // primitive as a memset, and now we need to create a real one - out->set_mem_desc({phi::vectorize(shape), - platform::MKLDNNGetDataType(), - platform::GetPlainMKLDNNFormat(shape.size())}); - } - - T CalculateFillValue(const framework::ExecutionContext& ctx) const { - const auto str_value = ctx.Attr("str_value"); - const auto float_value = ctx.Attr("value"); - - T value; - - if (str_value.empty()) { - value = static_cast(float_value); - } else { - // handle NaN/Inf first, which cannot be read from stream - if (str_value == "inf") { - value = static_cast(std::numeric_limits::infinity()); - } else if (str_value == "-inf") { - value = static_cast(-std::numeric_limits::infinity()); - } else if (str_value == "nan") { - value = static_cast(std::numeric_limits::quiet_NaN()); - } else { - std::stringstream convert_stream(str_value); - double tmp_value; - convert_stream >> tmp_value; - value = static_cast(tmp_value); - } - } - - if (ctx.HasInput("ValueTensor")) { - const auto* value_tensor = ctx.Input("ValueTensor"); - PADDLE_ENFORCE_EQ( - value_tensor->numel(), - 1, - platform::errors::InvalidArgument( - "When use Tensor as value to set Tensor value in fill_constant, " - "value input(ValueTensor) size must be 1, but got %d", - value_tensor->numel())); - value = value_tensor->data()[0]; - } - - return value; - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(fill_constant, - MKLDNN, - paddle::platform::CPUPlace, - ops::FillConstantMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc index 64d7bca4d06..c5d67e567b2 100644 --- a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc @@ -181,15 +181,3 @@ REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace, ops::InterpolateMKLDNNKernel); - -REGISTER_OP_KERNEL(nearest_interp_v2, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::InterpolateMKLDNNKernel, - ops::InterpolateMKLDNNKernel, - ops::InterpolateMKLDNNKernel, - ops::InterpolateMKLDNNKernel); -REGISTER_OP_KERNEL(bilinear_interp_v2, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::InterpolateMKLDNNKernel); diff --git a/paddle/phi/kernels/onednn/concat_grad_kernel.cc b/paddle/phi/kernels/onednn/concat_grad_kernel.cc new file mode 100644 index 00000000000..be962a96aca --- /dev/null +++ b/paddle/phi/kernels/onednn/concat_grad_kernel.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" + +namespace phi { + +template +void ConcatGradKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& out_grad, + const Scalar& axis_scalar, + std::vector x_grad) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto& astream = OneDNNContext::tls().get_stream(); + + for (size_t i = 0; i < x_grad.size(); ++i) { + if (x_grad[i] != nullptr) { + x_grad[i]->set_lod(x[i]->lod()); + } + } + + int axis = axis_scalar.to(); + + auto out_grad_vec_dims = vectorize(out_grad.dims()); + + axis = funcs::ComputeAxis(axis, out_grad_vec_dims.size()); + + std::vector offset(out_grad_vec_dims.size(), 0); + + dnnl::memory::data_type out_grad_type = + funcs::ToOneDNNDataType(out_grad.dtype()); + funcs::ReorderOneDNNHandler reorder_handler( + out_grad_vec_dims, out_grad.dtype(), out_grad_type, onednn_engine); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + out_grad.mem_desc(), funcs::to_void_cast(out_grad.data())); + + for (size_t i = 0; i < x_grad.size(); ++i) { + if (x_grad[i]->numel() != 0UL) { + auto x_grad_vec_dims = vectorize(x_grad[i]->dims()); + auto slice_mem_p = reorder_handler.AcquireSubmemory( + x_grad_vec_dims, offset, reorder_src_memory_p); + + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + x_grad[i], + x_grad_vec_dims, + funcs::GetPlainOneDNNFormat(x_grad_vec_dims.size()), + dev_ctx.GetPlace()); + auto reorder_p = + reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); + + reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); + + offset[axis] += x_grad[i]->dims()[axis]; + + x_grad[i]->set_mem_desc(reorder_dst_memory_p->get_desc()); + } + } + astream.wait(); +} +} // namespace phi + +PD_REGISTER_KERNEL(concat_grad, + OneDNN, + ALL_LAYOUT, + phi::ConcatGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/concat_kernel.cc b/paddle/phi/kernels/onednn/concat_kernel.cc new file mode 100644 index 00000000000..11ddc1884db --- /dev/null +++ b/paddle/phi/kernels/onednn/concat_kernel.cc @@ -0,0 +1,164 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/concat_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" + +namespace phi { +using memory = dnnl::memory; + +namespace funcs { + +template +class ConcatOneDNNHandler : public OneDNNHandlerNoCachingT { + public: + ConcatOneDNNHandler(Place cpu_place, + int concat_axis, + const dnnl::engine onednn_engine, + const std::vector& inputs, + DenseTensor* output) + : OneDNNHandlerNoCachingT(onednn_engine, cpu_place) { + const int rank = inputs[0]->dims().size(); + + PADDLE_ENFORCE_EQ( + concat_axis >= -rank && concat_axis < rank, + true, + errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, + rank, + concat_axis)); + + if (concat_axis < 0) { + concat_axis = concat_axis + rank; + } + + memory::data_type dt = ToOneDNNDataType(inputs[0]->dtype()); + std::vector srcs_md; + srcs_md.reserve(inputs.size()); + + // Create memory descriptors for each of inputs + for (size_t i = 0; i < inputs.size(); ++i) { + srcs_md.push_back(inputs[i]->mem_desc()); + } + + auto dst_dims = vectorize(output->dims()); + + memory::desc dst_md = memory::desc(dst_dims, dt, OneDNNMemoryFormat::any); + + this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md); + } + + // (jczaja) concat oneDNN prim is not having .desc attribute so + // we cannot use base AcquireForwardPrimitiveDescriptor + void AcquireForwardPrimitiveDescriptor( + const memory::desc& dst_md, + const int concat_axis, + const std::vector& srcs_md) { + this->fwd_pd_.reset(new dnnl::concat::primitive_desc( + dst_md, concat_axis, srcs_md, this->engine_)); + } + + std::shared_ptr AcquireSrcMemory(const DenseTensor& input, + int i) { + const T* input_data = input.data(); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i), + to_void_cast(input_data)); + } +}; +} // namespace funcs + +static void EnforceLayouts(const std::vector inputs) { + for (auto* input : inputs) { + PADDLE_ENFORCE_EQ( + input->layout(), + DataLayout::ONEDNN, + errors::InvalidArgument("Wrong layout set for Input tensor")); + } +} + +// From a multi-input, gather only nonempty inputs +static const std::vector ReduceMultiInput( + const std::vector& inputs) { + std::vector reduced(inputs.size()); + auto end_it = std::copy_if( + inputs.begin(), inputs.end(), reduced.begin(), [](const DenseTensor* t) { + return t->numel() > 0; + }); + reduced.resize(std::distance(reduced.begin(), end_it)); + return reduced; +} + +template +void ConcatKernel(const Context& dev_ctx, + const std::vector& x, + const Scalar& axis, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + // If any of the multiple inputs of concat has an input size of 0, the + // actual size of the multi_input will change + auto multi_input = ReduceMultiInput(x); + EnforceLayouts(multi_input); + + auto out_dims_vec = vectorize(out->dims()); + if (std::any_of(out_dims_vec.begin(), out_dims_vec.end(), [](int64_t i) { + return i < 0; + })) { + std::vector x_dims; + x_dims.reserve(x.size()); + for (size_t i = 0; i < x.size(); ++i) { + x_dims.push_back(x[i]->dims()); + } + + DDim out_dims = + funcs::ComputeAndCheckShape(true, x_dims, axis.to()); + out->Resize(out_dims); + } + + funcs::ConcatOneDNNHandler handler( + dev_ctx.GetPlace(), axis.to(), onednn_engine, multi_input, out); + + std::vector> srcs; + srcs.reserve(multi_input.size()); + + auto dst_mem = handler.AcquireDstMemory(out); + auto concat_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + std::unordered_map args; + for (size_t i = 0; i < multi_input.size(); ++i) { + srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i)); + args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))}); + } + args.insert({DNNL_ARG_DST, *dst_mem}); + + concat_p->execute(astream, args); + astream.wait(); + + out->set_mem_desc(dst_mem->get_desc()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(concat, + OneDNN, + ALL_LAYOUT, + phi::ConcatKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} diff --git a/paddle/phi/kernels/onednn/expand_grad_kernel.cc b/paddle/phi/kernels/onednn/expand_grad_kernel.cc new file mode 100644 index 00000000000..dd8afdd8467 --- /dev/null +++ b/paddle/phi/kernels/onednn/expand_grad_kernel.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/expand_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void ExpandGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const IntArray& shape, + DenseTensor* in_grad) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto in_grad_vec_dims = vectorize(in_grad->dims()); + auto out_grad_vec_dims = vectorize(out_grad.dims()); + + if (in_grad_vec_dims.size() != out_grad_vec_dims.size()) { + in_grad_vec_dims.insert(in_grad_vec_dims.begin(), + out_grad_vec_dims.size() - in_grad_vec_dims.size(), + 1); + } + + auto& astream = OneDNNContext::tls().get_stream(); + if (out_grad_vec_dims == in_grad_vec_dims) { + dnnl::memory::data_type out_grad_type = + funcs::ToOneDNNDataType(out_grad.dtype()); + funcs::ReorderOneDNNHandler reorder_handler( + out_grad_vec_dims, out_grad.dtype(), out_grad_type, onednn_engine); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + out_grad.mem_desc(), funcs::to_void_cast(out_grad.data())); + + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + in_grad, + funcs::GetPlainOneDNNFormat(in_grad_vec_dims.size()), + dev_ctx.GetPlace()); + + auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, + reorder_dst_memory_p); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + in_grad->set_mem_desc(reorder_dst_memory_p->get_desc()); + } else { + funcs::ReductionOneDNNHandler handler(dnnl::algorithm::reduction_sum, + 0.0f, + 0.0f, + onednn_engine, + dev_ctx.GetPlace(), + &out_grad, + in_grad, + in_grad_vec_dims); + + auto src_memory_p = handler.AcquireSrcMemory(&out_grad); + auto dst_memory_p = handler.AcquireDstMemory(in_grad); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; + + auto reduction_p = handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, reduction_args); + astream.wait(); + in_grad->set_layout(DataLayout::ONEDNN); + in_grad->set_mem_desc( + dst_memory_p->get_desc().reshape(vectorize(in_grad->dims()))); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(expand_grad, + OneDNN, + ALL_LAYOUT, + phi::ExpandGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/expand_kernel.cc b/paddle/phi/kernels/onednn/expand_kernel.cc new file mode 100644 index 00000000000..52d12bb100d --- /dev/null +++ b/paddle/phi/kernels/onednn/expand_kernel.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/expand_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +std::vector GetExtendedXDims(const std::vector& x_vec_dims, + int new_size) { + std::vector extended_x_dims(new_size, 1); + std::copy(x_vec_dims.begin(), + x_vec_dims.end(), + extended_x_dims.begin() + new_size - x_vec_dims.size()); + + return extended_x_dims; +} + +template +void ExpandKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto x_vec_dims = vectorize(x.dims()); + + auto out_new_dims = shape.GetData(); + + for (size_t i = 0; i < out_new_dims.size(); ++i) { + out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i]; + } + + if (x_vec_dims.size() != out_new_dims.size()) { + x_vec_dims = GetExtendedXDims(x_vec_dims, out_new_dims.size()); + } + + out->Resize(make_ddim(out_new_dims)); + funcs::BroadcastDataOneDNNHandler handler(dnnl::algorithm::binary_add, + onednn_engine, + dev_ctx.GetPlace(), + &x, + out, + 0.0f, + 1.0f, + x_vec_dims); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + auto dst_memory_p = handler.AcquireZeroedDstMemory(out); + auto binary_p = handler.AcquireForwardPrimitive(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *dst_memory_p}, + {DNNL_ARG_SRC_1, *src_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = OneDNNContext::tls().get_stream(); + binary_p->execute(astream, args); + astream.wait(); + + out->set_mem_desc(dst_memory_p->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL(expand, + OneDNN, + ALL_LAYOUT, + phi::ExpandKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/full_kernel.cc b/paddle/phi/kernels/onednn/full_kernel.cc new file mode 100644 index 00000000000..5a444175bfb --- /dev/null +++ b/paddle/phi/kernels/onednn/full_kernel.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/full_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +namespace funcs { + +template +class FillConstantOneDNNHandler + : public OneDNNHandlerNoCachingT { + public: + FillConstantOneDNNHandler(DenseTensor* out, + dnnl::engine engine, + Place cpu_place) + : OneDNNHandlerNoCachingT(engine, cpu_place) { + const auto src0_md = dnnl::memory::desc({out->numel(), sizeof(T)}, + OneDNNGetDataType(), + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attrs; + attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f}); + + this->AcquireForwardPrimitiveDescriptor( + attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md); + } + + static const dnnl::memory::desc src1_md; +}; + +template +const dnnl::memory::desc FillConstantOneDNNHandler::src1_md( + {1, sizeof(T)}, OneDNNGetDataType(), dnnl::memory::format_tag::ab); +} // namespace funcs + +template +void FullKernel(const Context& dev_ctx, + const IntArray& shape, + const Scalar& val, + DataType dtype, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + T fill_value = val.to(); + out->Resize(make_ddim(shape.GetData())); + + funcs::FillConstantOneDNNHandler handler( + out, onednn_engine, dev_ctx.GetPlace()); + + dnnl::memory constant_value_memory = + dnnl::memory(funcs::FillConstantOneDNNHandler::src1_md, + onednn_engine, + reinterpret_cast(&fill_value)); + + auto src0_memory_p = handler.AcquireDstMemory(out); + auto fill_constant_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + fill_constant_p->execute(astream, + {{DNNL_ARG_SRC_0, *src0_memory_p}, + {DNNL_ARG_SRC_1, constant_value_memory}, + {DNNL_ARG_DST, *src0_memory_p}}); + astream.wait(); + + // src0_memory_p's md was just to allow the usage of a binary + // primitive as a memset, and now we need to create a real one + out->set_mem_desc({vectorize(out->dims()), + funcs::OneDNNGetDataType(), + funcs::GetPlainOneDNNFormat(out->dims().size())}); +} +} // namespace phi + +PD_REGISTER_KERNEL(full, OneDNN, ALL_LAYOUT, phi::FullKernel, float) {} diff --git a/paddle/phi/kernels/onednn/interpolate_kernel.cc b/paddle/phi/kernels/onednn/interpolate_kernel.cc new file mode 100644 index 00000000000..f70b9bcaf1a --- /dev/null +++ b/paddle/phi/kernels/onednn/interpolate_kernel.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/interpolate_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/interpolate_function.h" + +namespace phi { + +namespace funcs { +template +class InterpolateOneDNNHandler + : public OneDNNHandlerNoCachingT { + public: + InterpolateOneDNNHandler(const dnnl::algorithm algo, + const dnnl::engine engine, + Place cpu_place, + const DenseTensor* x, + DenseTensor* out) + : OneDNNHandlerNoCachingT(engine, + cpu_place) { + const auto dst_tz = vectorize(out->dims()); + const auto dst_md = dnnl::memory::desc( + dst_tz, OneDNNGetDataType(), OneDNNMemoryFormat::any); + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_inference, algo, x->mem_desc(), dst_md); + } +}; +} // namespace funcs + +std::vector ComputeOutputShape( + const DenseTensor* x, + const paddle::optional& out_size, + const paddle::optional>& size_tensor, + const paddle::optional& scale_tensor, + const std::string& data_layout, + int out_d, + int out_h, + int out_w, + const std::vector& scale_attr) { + const auto& in_dims = x->dims(); + const DDim in_dhw_dims = slice_ddim(in_dims, 2, in_dims.size()); + + std::vector out_dims; + out_dims.reserve(5); + if (in_dhw_dims.size() == 1) { + out_dims.push_back(out_w); + } else if (in_dhw_dims.size() == 2) { + out_dims.push_back(out_h); + out_dims.push_back(out_w); + } else if (in_dhw_dims.size() == 3) { + out_dims.push_back(out_d); + out_dims.push_back(out_h); + out_dims.push_back(out_w); + } + + if (size_tensor && size_tensor.get().size() > 0) { + auto new_size = funcs::get_new_shape(size_tensor.get()); + if (new_size.size() == out_dims.size()) { + out_dims = new_size; + } + } else if (out_size) { + auto out_size_data = + funcs::get_new_data_from_tensor(out_size.get_ptr()); + if (out_size_data.size() == out_dims.size()) { + out_dims = out_size_data; + } + } else { + std::vector scale; + scale.reserve(3); + if (scale_tensor) { + auto scale_data = + funcs::get_new_data_from_tensor(scale_tensor.get_ptr()); + scale.resize(3, scale_data[0]); + std::copy(scale_data.begin(), scale_data.end(), scale.begin()); + } else { + if (scale_attr.size() > 0) { + scale.resize(3, scale_attr[0]); + std::copy(scale_attr.begin(), scale_attr.end(), scale.begin()); + } + } + + if (scale.size() == 3 && scale[0] > 0.0f && scale[1] > 0.0f && + scale[2] > 0.0f) { + int j = 0; + std::vector in_dhw_vec = vectorize(in_dhw_dims); + std::transform( + in_dhw_vec.begin(), + in_dhw_vec.end(), + out_dims.begin(), + [&](int64_t i) -> int { return static_cast(i * scale[j++]); }); + } + } + + PADDLE_ENFORCE_GT( + std::all_of( + out_dims.begin(), out_dims.end(), [](int i) { return i > 0; }), + 0, + errors::InvalidArgument("out_d, out_h, out_w of Op(interpolate) " + "should be greater than 0.")); + + const std::vector nc_dims = {in_dims[0], in_dims[1]}; + out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end()); + return out_dims; +} + +template +void InterpolateKernel( + const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& out_size, + const paddle::optional>& size_tensor, + const paddle::optional& scale_tensor, + const std::string& data_layout, + int out_d, + int out_h, + int out_w, + const std::vector& scale, + const std::string& interp_method, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + const dnnl::algorithm algo = (interp_method == "nearest") + ? dnnl::algorithm::resampling_nearest + : dnnl::algorithm::resampling_linear; + + const auto out_dims_vec = ComputeOutputShape(&x, + out_size, + size_tensor, + scale_tensor, + data_layout, + out_d, + out_h, + out_w, + scale); + DDim dim_out = make_ddim(out_dims_vec); + out->Resize(dim_out); + + funcs::InterpolateOneDNNHandler handler( + algo, onednn_engine, dev_ctx.GetPlace(), &x, out); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + auto dst_memory_p = handler.AcquireDstMemory(out); + + auto resampling_prim = handler.AcquireForwardPrimitive(); + const std::unordered_map args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; + auto& astream = OneDNNContext::tls().get_stream(); + + resampling_prim->execute(astream, args); + astream.wait(); + + out->set_mem_desc(dst_memory_p->get_desc()); +} + +template +void BilinearInterpKernel( + const Context& ctx, + const DenseTensor& x, + const paddle::optional& out_size, + const paddle::optional>& size_tensor, + const paddle::optional& scale_tensor, + const std::string& data_layout, + int out_d, + int out_h, + int out_w, + const std::vector& scale, + const std::string& interp_method, + bool align_corners, + int align_mode, + DenseTensor* output) { + InterpolateKernel(ctx, + x, + out_size, + size_tensor, + scale_tensor, + data_layout, + out_d, + out_h, + out_w, + scale, + interp_method, + output); +} + +template +void NearestInterpKernel( + const Context& ctx, + const DenseTensor& x, + const paddle::optional& out_size, + const paddle::optional>& size_tensor, + const paddle::optional& scale_tensor, + const std::string& data_layout, + int out_d, + int out_h, + int out_w, + const std::vector& scale, + const std::string& interp_method, + bool align_corners, + int align_mode, + DenseTensor* output) { + InterpolateKernel(ctx, + x, + out_size, + size_tensor, + scale_tensor, + data_layout, + out_d, + out_h, + out_w, + scale, + interp_method, + output); +} +} // namespace phi + +PD_REGISTER_KERNEL( + bilinear_interp, OneDNN, ALL_LAYOUT, phi::BilinearInterpKernel, float) {} + +PD_REGISTER_KERNEL(nearest_interp, + OneDNN, + ALL_LAYOUT, + phi::NearestInterpKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} -- GitLab