diff --git a/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
deleted file mode 100644
index 1e546e44fa241658718d38c81a5e0c7ae9cf80ba..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
+++ /dev/null
@@ -1,146 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/utils.h"
-#include "paddle/fluid/platform/mkldnn_reuse.h"
-namespace paddle {
-namespace operators {
-
-using dnnl::concat;
-using dnnl::memory;
-using dnnl::primitive;
-using dnnl::stream;
-using framework::DataLayout;
-using framework::LoDTensor;
-using framework::Tensor;
-using platform::to_void_cast;
-
-template <typename T>
-class StackMKLDNNHandler
-    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
- public:
-  StackMKLDNNHandler(const framework::ExecutionContext& ctx,
-                     const dnnl::engine mkldnn_engine,
-                     const std::vector<const Tensor*>& inputs,
-                     Tensor* output)
-      : platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
-                                                           ctx.GetPlace()) {
-    int stack_axis = ctx.Attr<int>("axis");
-
-    int ndims = inputs[0]->dims().size();
-
-    if (stack_axis < 0) {
-      stack_axis = ndims + 1 + stack_axis;  // +1 to match output's ndims
-    }
-
-    // in stack op all inputs must have same dims
-    auto input_dims = phi::vectorize(inputs[0]->dims());
-
-    memory::data_type dt = framework::ToMKLDNNDataType(
-        framework::TransToProtoVarType(inputs[0]->dtype()));
-    std::vector<memory::desc> srcs_md;
-    memory::desc dst_md;
-    MKLDNNMemoryFormat dst_fmt;
-
-    srcs_md.reserve(inputs.size());
-
-    // if stack is not done on the last (non-existing) axis, then we can
-    // optimize the concat primitive by not adding an additional dimension,
-    // since that causes wrong output format deduction and suboptimal
-    // performance as a result
-    if (stack_axis != ndims) {
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.push_back(inputs[i]->mem_desc());
-      }
-
-      input_dims[stack_axis] *= inputs.size();
-      dst_md = memory::desc(input_dims, dt, MKLDNNMemoryFormat::any);
-    } else {
-      auto extended_input_dims = phi::vectorize(output->dims());
-      extended_input_dims[stack_axis] = 1;
-
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims));
-      }
-
-      // the concat primitive chooses a suboptimal format tag because it
-      // cannot distinguish between e.g. abcd and abdc if the last dim is
-      // equal to 1, so enforcing a format is needed for better performance
-      dst_fmt = platform::GetPlainMKLDNNFormat(extended_input_dims.size());
-      dst_md = memory::desc(phi::vectorize(output->dims()), dt, dst_fmt);
-    }
-
-    this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md);
-  }
-
-  // the concat oneDNN primitive does not have a .desc attribute, so we
-  // cannot use the default AcquireForwardPrimitiveDescriptor
-  void AcquireForwardPrimitiveDescriptor(
-      const memory::desc& dst_md,
-      const int stack_axis,
-      const std::vector<memory::desc>& srcs_md) {
-    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
-        dst_md, stack_axis, srcs_md, this->engine_));
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const Tensor& input, int i) {
-    const T* input_data = input.data<T>();
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
-                                            to_void_cast<T>(input_data));
-  }
-};
-
-template <typename T>
-class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    auto multi_input = ctx.MultiInput<Tensor>("X");
-
-    Tensor* output = ctx.Output<Tensor>("Y");
-
-    StackMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);
-
-    std::vector<std::shared_ptr<dnnl::memory>> srcs;
-    srcs.reserve(multi_input.size());
-
-    auto dst_mem = handler.AcquireDstMemory(output);
-    auto concat_p = handler.AcquireForwardPrimitive();
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    std::unordered_map<int, dnnl::memory> args;
-    for (size_t i = 0; i < multi_input.size(); ++i) {
-      srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
-
-    concat_p->execute(astream, args);
-    astream.wait();
-
-    output->set_mem_desc(
-        dst_mem->get_desc().reshape(phi::vectorize(output->dims())));
-  }
-};
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_KERNEL(stack,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   ops::StackMKLDNNOpKernel<float>);
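Both the deleted fluid handler above and its phi replacement at the end of this patch share the same shape logic: stacking N inputs of rank `ndims` yields an output of rank `ndims + 1`, so a negative axis is normalized against the output rank (`ndims + 1 + axis`). A minimal standalone sketch of that logic, with a hypothetical helper name:

```cpp
// Hypothetical sketch of the stack shape/axis logic used by the handler:
// stacking num_inputs tensors of shape input_dims inserts a new dimension
// of size num_inputs at the (normalized) axis.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> StackedDims(const std::vector<int64_t>& input_dims,
                                 int axis, int64_t num_inputs) {
  const int ndims = static_cast<int>(input_dims.size());
  if (axis < 0) axis = ndims + 1 + axis;  // +1: normalize against output rank
  assert(axis >= 0 && axis <= ndims);
  std::vector<int64_t> out = input_dims;
  out.insert(out.begin() + axis, num_inputs);
  return out;
}
```

For four `[2, 3]` inputs, axis `0` gives `[4, 2, 3]` and axis `-1` gives `[2, 3, 4]`, which is why the kernel can lower stack onto oneDNN concat over reshaped inputs.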
diff --git a/paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc b/paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc
deleted file mode 100644
index e332972f7576abd70de442a4a0a68eec86a572ec..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <cstring>
-
-#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
-#include "paddle/fluid/operators/optimizers/sgd_op.h"
-
-namespace pplat = paddle::platform;
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class SGDOneDNNKernel : public SGDOpKernel<phi::CPUContext, T> {
- protected:
-  void dense_param_and_grad_kernel(
-      const framework::ExecutionContext &ctx) const override {
-    VLOG(4) << "[ONEDNN]: sgd_dense_param_kernel";
-    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    const auto *param = ctx.Input<framework::Tensor>("Param");
-    auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
-    const auto *grad = ctx.Input<framework::Tensor>("Grad");
-
-    auto *out_data = param_out->mutable_data<T>(ctx.GetPlace());
-    const T *param_data = param->data<T>();
-    const auto *grad_data = grad->data<T>();
-    const auto *lr = learning_rate->data<T>();
-    // Since dense SGD is not an in-place operation, first copy params to the
-    // output tensor and then update it.
-    std::memcpy(out_data, param_data, param->memory_size());
-    OneDNNAXPYHandler<T>(param_out->numel(), -lr[0])(grad_data, out_data);
-  }
-
-  void dense_param_sparse_grad_kernel(
      const framework::ExecutionContext &ctx) const override {
-    VLOG(4) << "[ONEDNN]: sgd_dense_param_kernel";
-    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
-    const auto *grad = ctx.Input<phi::SelectedRows>("Grad");
-
-    const auto &grad_value = grad->value();
-    const auto &grad_rows = grad->rows();
-    const auto grad_height = grad->height();
-    const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
-    const auto grad_width = grad_value.numel() / grad_val_height;
-
-    const auto *grad_data = grad_value.data<T>();
-    auto *out_data = param_out->data<T>();
-    const auto *lr = learning_rate->data<T>();
-
-    OneDNNAXPYHandler<T> axpy_handler(grad_width, -lr[0]);
-
-    for (size_t i = 0; i < grad_rows.size(); ++i) {
-      PADDLE_ENFORCE_LT(
-          grad_rows[i],
-          grad_height,
-          pplat::errors::OutOfRange(
-              "Grad rows index value should be less than grad height. "
-              "Got [%s], but expected less than [%s]",
-              grad_rows[i],
-              grad_height));
-      const int64_t row = grad_rows[i];
-      const auto *src = grad_data + i * grad_width;
-      auto *dst = out_data + row * grad_width;
-      axpy_handler(src, dst);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(sgd,
-                   MKLDNN,
-                   pplat::CPUPlace,
-                   ops::SGDOneDNNKernel<float>,
-                   ops::SGDOneDNNKernel<pplat::bfloat16>);
diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt
index 9a26aed5f341b57c6813b38b0c1eebc1ab0d9001..9bc9573529241ec8c84aeec25cd1c0a7a0203b6c 100644
--- a/paddle/phi/backends/CMakeLists.txt
+++ b/paddle/phi/backends/CMakeLists.txt
@@ -21,6 +21,7 @@ endif()
 
 if(WITH_MKLDNN)
   list(APPEND BACKENDS_SRCS onednn/onednn_context.cc)
+  list(APPEND BACKENDS_SRCS onednn/axpy_handler.cc)
   list(APPEND BACKENDS_DEPS mkldnn)
 endif()
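Both the deleted kernel above and the phi kernel added later in this patch implement the dense update as a copy followed by AXPY: `param_out = param + (-lr) * grad`. A plain C++ reference sketch of those semantics, with a hypothetical function name:

```cpp
// Reference semantics of the dense SGD path: copy, then AXPY with
// alpha = -lr. Mirrors what OneDNNAXPYHandler computes on the copied buffer.
#include <cstdint>
#include <cstring>

void SgdDenseReference(const float* param, const float* grad, float lr,
                       float* param_out, int64_t n) {
  // dense SGD is out-of-place: copy params into the output first
  std::memcpy(param_out, param, n * sizeof(float));
  for (int64_t i = 0; i < n; ++i) {
    param_out[i] += -lr * grad[i];  // y += alpha * x with alpha = -lr
  }
}
```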
diff --git a/paddle/phi/backends/onednn/axpy_handler.cc b/paddle/phi/backends/onednn/axpy_handler.cc
new file mode 100644
index 0000000000000000000000000000000000000000..df61948d62215b0d95db7fd18fbbc9c40a8c3d10
--- /dev/null
+++ b/paddle/phi/backends/onednn/axpy_handler.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/onednn/axpy_handler.h"
+
+#include <cinttypes>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/phi/backends/onednn/onednn_helper.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename T>
+class AXPYHandler {
+ public:
+  AXPYHandler(const dnnl::engine onednn_engine, int n, float alpha) {
+    OneDNNContext::tls().log_lib_version();
+    auto md = dnnl::memory::desc(
+        {n}, OneDNNGetDataType<T>(), dnnl::memory::format_tag::x);
+    src_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE);
+    dst_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE);
+    dnnl::primitive_attr reorder_attr;
+    dnnl::post_ops post_operations;
+    if (alpha != 1.f) {
+      std::vector<float> scales(1, alpha);
+      reorder_attr.set_output_scales(0, scales);
+    }
+    post_operations.append_sum(1.0f);
+
+    reorder_attr.set_post_ops(post_operations);
+    reorder_p_ = dnnl::reorder(src_mem_, dst_mem_, reorder_attr);
+  }
+
+  dnnl::memory &AcquireSrcMemory(const T *x) {
+    src_mem_.set_data_handle(to_void_cast<T>(x));
+    return src_mem_;
+  }
+
+  dnnl::memory &AcquireDstMemory(T *y) {
+    dst_mem_.set_data_handle(y);
+    return dst_mem_;
+  }
+
+  const dnnl::reorder &AcquireReorder() { return reorder_p_; }
+
+ private:
+  dnnl::memory src_mem_;
+  dnnl::memory dst_mem_;
+  dnnl::reorder reorder_p_;
+};
+
+template class AXPYHandler<float>;
+template class AXPYHandler<phi::dtype::bfloat16>;
+
+template <typename T>
+static void naive_axpy(int n, T alpha, const T *x, T *y) {
+  while (n-- > 0) {
+    *y += alpha * *x;
+    ++y;
+    ++x;
+  }
+}
+
+template <typename T>
+class OneDNNAXPYHandler<T>::Impl {
+ public:
+  Impl(int64_t n, T alpha, const dnnl::engine onednn_engine);
+  void operator()(const T *x, T *y);
+
+ private:
+  std::unique_ptr<AXPYHandler<T>> handler_;
+  int64_t n_;
+  T alpha_;
+};
+
+template <typename T>
+OneDNNAXPYHandler<T>::Impl::Impl(int64_t n,
+                                 T alpha,
+                                 const dnnl::engine onednn_engine)
+    : n_{n}, alpha_{alpha} {
+  handler_ = std::make_unique<AXPYHandler<T>>(
+      onednn_engine, n, static_cast<float>(alpha));
+}
+
+template <typename T>
+void OneDNNAXPYHandler<T>::Impl::operator()(const T *x, T *y) {
+  if (this->n_ < 100) {
+    naive_axpy(this->n_, this->alpha_, x, y);
+    return;
+  }
+
+  auto &reorder_src_mem_p = handler_->AcquireSrcMemory(x);
+  auto &reorder_dst_mem_p = handler_->AcquireDstMemory(y);
+  auto reorder_p = handler_->AcquireReorder();
+  auto &astream = OneDNNContext::tls().get_stream();
+  reorder_p.execute(astream, reorder_src_mem_p, reorder_dst_mem_p);
+  astream.wait();
+}
+
+template <typename T>
+OneDNNAXPYHandler<T>::OneDNNAXPYHandler(int64_t n,
+                                        T alpha,
+                                        const dnnl::engine onednn_engine)
+    : pimpl_{new Impl{n, alpha, onednn_engine},
+             [](Impl *impl) { delete impl; }} {
+  VLOG(4) << "[OneDNN] OneDNNAXPYHandler<" << typeid(T).name() << ">, "
+          << "n: " << n << ", alpha: " << alpha;
+}
+
+template <typename T>
+void OneDNNAXPYHandler<T>::operator()(const T *x, T *y) {
+  pimpl_->operator()(x, y);
+}
+
+template class OneDNNAXPYHandler<float>;
+template class OneDNNAXPYHandler<phi::dtype::bfloat16>;
+
+}  // namespace funcs
+}  // namespace phi
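The handler above encodes AXPY as a oneDNN reorder: the output scale multiplies the source by `alpha`, and the `sum` post-op accumulates into the existing destination, so the reorder computes `y = alpha * x + y`. A standalone sketch of the same construction against the oneDNN v2.x API this patch uses (function name hypothetical; engine and stream are assumed to exist):

```cpp
// Sketch: AXPY via dnnl::reorder with output scale + sum post-op,
// mirroring the calls made in AXPYHandler's constructor above.
#include <vector>
#include "dnnl.hpp"

void AxpyViaReorder(dnnl::engine& eng, dnnl::stream& strm,
                    const float* x, float* y, int64_t n, float alpha) {
  auto md = dnnl::memory::desc(
      {n}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  dnnl::memory src(md, eng, const_cast<float*>(x));
  dnnl::memory dst(md, eng, y);

  dnnl::primitive_attr attr;
  attr.set_output_scales(0, {alpha});  // scale the source by alpha
  dnnl::post_ops ops;
  ops.append_sum(1.0f);                // accumulate into existing dst values
  attr.set_post_ops(ops);

  dnnl::reorder(src, dst, attr).execute(strm, src, dst);
  strm.wait();
}
```

Note the `n_ < 100` cutoff in `operator()` above: for short vectors the primitive setup overhead dominates, so the scalar loop wins.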
diff --git a/paddle/phi/backends/onednn/axpy_handler.h b/paddle/phi/backends/onednn/axpy_handler.h
new file mode 100644
index 0000000000000000000000000000000000000000..dd9a8108f59b05fb27fec41026961aba60dc167b
--- /dev/null
+++ b/paddle/phi/backends/onednn/axpy_handler.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "dnnl.hpp"  // NOLINT
+
+namespace phi {
+namespace funcs {
+///
+/// @brief Helper class for AXPY execution using the oneDNN library.
+///
+/// @tparam T Data type.
+///
+template <typename T>
+class OneDNNAXPYHandler {
+ public:
+  OneDNNAXPYHandler(OneDNNAXPYHandler&) = delete;
+  OneDNNAXPYHandler(OneDNNAXPYHandler&&) = delete;
+  OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&) = delete;
+  OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&&) = delete;
+  ///
+  /// @brief Constructor.
+  ///
+  /// @param[in]  n              The number of elements in the tensor (assumed
+  ///                            to be 1D).
+  /// @param[in]  alpha          The alpha coefficient.
+  /// @param[in]  onednn_engine  The oneDNN engine.
+  ///
+  OneDNNAXPYHandler(int64_t n, T alpha, dnnl::engine onednn_engine);
+  ///
+  /// @brief Executes AXPY.
+  ///
+  /// @param[in]   x  The pointer to input X tensor data.
+  /// @param[out]  y  The pointer to output Y tensor data.
+  ///
+  void operator()(const T* x, T* y);
+
+ private:
+  OneDNNAXPYHandler() = delete;
+  // (arogowie-intel) Private implementation idiom to hide dependency
+  // on oneDNN headers.
+  class Impl;
+  // We need a custom deleter, since the compiler is unable to parameterize
+  // an allocator's default deleter due to the incomplete type.
+  std::unique_ptr<Impl, void (*)(Impl*)> pimpl_;
+};
+}  // namespace funcs
+}  // namespace phi
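The custom-deleter comment above is the crux of this header: `std::unique_ptr<Impl>` would instantiate `std::default_delete<Impl>`, which requires `Impl` to be complete wherever the destructor may be generated, while a function-pointer deleter is bound only in the source file where `Impl` is fully defined. A minimal illustration of the pattern, with hypothetical class names:

```cpp
// pimpl with a function-pointer deleter: the header never needs ~Impl().
#include <memory>

class Widget {
 public:
  Widget();
  void Run();

 private:
  class Impl;  // intentionally incomplete here (defined in the .cc file)
  std::unique_ptr<Impl, void (*)(Impl*)> pimpl_;
};

// --- in the .cc file, where Impl is complete ---
class Widget::Impl {
 public:
  void Run() { /* real work */ }
};

Widget::Widget()
    : pimpl_{new Impl{}, [](Impl* p) { delete p; }} {}  // deleter bound here

void Widget::Run() { pimpl_->Run(); }
```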
+ +#include "paddle/phi/kernels/sgd_kernel.h" + +#include "paddle/phi/backends/onednn/axpy_handler.h" +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const paddle::optional& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + auto* out_data = dev_ctx.template Alloc(param_out); + const T* param_data = param.data(); + const auto* grad_data = grad.data(); + const auto* lr = learning_rate.data(); + // Since denese SGD is not in place operation, first copy params to output + // tensor and then update it. + std::memcpy(out_data, param_data, param.memory_size()); + funcs::OneDNNAXPYHandler(param_out->numel(), -lr[0], dev_ctx.GetEngine())( + grad_data, out_data); +} + +template +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const paddle::optional& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + const auto& grad_value = grad.value(); + const auto& grad_rows = grad.rows(); + const auto grad_height = grad.height(); + const int64_t grad_val_height = static_cast(grad_rows.size()); + const auto grad_width = grad_value.numel() / grad_val_height; + + const auto* grad_data = grad_value.data(); + auto* out_data = param_out->data(); + const auto* lr = learning_rate.data(); + + funcs::OneDNNAXPYHandler axpy_handler( + grad_width, -lr[0], dev_ctx.GetEngine()); + + for (size_t i = 0; i < grad_rows.size(); ++i) { + PADDLE_ENFORCE_LT( + grad_rows[i], + grad_height, + errors::OutOfRange( + "Grad rows index value should be less than grad height." + "Got [%s], but expected less than [%s]", + grad_rows[i], + grad_height)); + const int64_t row = grad_rows[i]; + const auto* src = grad_data + i * grad_width; + auto* dst = out_data + row * grad_width; + axpy_handler(src, dst); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + sgd, OneDNN, ALL_LAYOUT, phi::SGDDenseKernel, float, phi::dtype::bfloat16) { +} + +PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, + OneDNN, + ALL_LAYOUT, + phi::SGDDenseParamSparseGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/stack_kernel.cc b/paddle/phi/kernels/onednn/stack_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ede31952e88d1caa3e3ca32ec6806151d74a4dd --- /dev/null +++ b/paddle/phi/kernels/onednn/stack_kernel.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/stack_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +namespace funcs { +template +class StackOneDNNHandler : public OneDNNHandlerNoCachingT { + public: + StackOneDNNHandler(const Place& cpu_place, + int stack_axis, + const dnnl::engine onednn_engine, + const std::vector& inputs, + DenseTensor* output) + : OneDNNHandlerNoCachingT(onednn_engine, cpu_place) { + int ndims = inputs[0]->dims().size(); + + if (stack_axis < 0) { + stack_axis = ndims + 1 + stack_axis; // +1 to match output's ndims + } + + // in stack op all inputs must have same dims + auto input_dims = vectorize(inputs[0]->dims()); + + dnnl::memory::data_type dt = ToOneDNNDataType(inputs[0]->dtype()); + std::vector srcs_md; + dnnl::memory::desc dst_md; + OneDNNMemoryFormat dst_fmt; + + srcs_md.reserve(inputs.size()); + + // if stack is not done on last(non existing) axis, then we can optimize + // concat primitive by not adding additional dimension, since it causes + // wrong output format deduction and suboptimal performance as a result + if (stack_axis != ndims) { + for (size_t i = 0; i < inputs.size(); ++i) { + srcs_md.push_back(inputs[i]->mem_desc()); + } + + input_dims[stack_axis] *= inputs.size(); + dst_md = dnnl::memory::desc(input_dims, dt, OneDNNMemoryFormat::any); + } else { + auto extended_input_dims = vectorize(output->dims()); + extended_input_dims[stack_axis] = 1; + + for (size_t i = 0; i < inputs.size(); ++i) { + srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims)); + } + + // concat primitive choses suboptimal format tag because it cannot + // distinguish between f.e. abcd and abdc if last dim is equal to 1 so + // enforcing is needed for better performance + dst_fmt = GetPlainOneDNNFormat(extended_input_dims.size()); + dst_md = dnnl::memory::desc(vectorize(output->dims()), dt, dst_fmt); + } + + this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md); + } + + // concat oneDNN prim is not having .desc attribute so we cannot use default + // AcquireForwardPrimitiveDescriptor + void AcquireForwardPrimitiveDescriptor( + const memory::desc& dst_md, + const int stack_axis, + const std::vector& srcs_md) { + this->fwd_pd_.reset(new dnnl::concat::primitive_desc( + dst_md, stack_axis, srcs_md, this->engine_)); + } + + std::shared_ptr AcquireSrcMemory(const DenseTensor& input, + int i) { + const T* input_data = input.data(); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i), + to_void_cast(input_data)); + } +}; +} // namespace funcs + +template +void StackKernel(const Context& dev_ctx, + const std::vector& multi_input, + int axis, + DenseTensor* output) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + funcs::StackOneDNNHandler handler( + dev_ctx.GetPlace(), axis, onednn_engine, multi_input, output); + + std::vector> srcs; + srcs.reserve(multi_input.size()); + + auto dst_mem = handler.AcquireDstMemory(output); + auto concat_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + std::unordered_map args; + for (size_t i = 0; i < multi_input.size(); ++i) { + srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i)); + args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))}); + } + args.insert({DNNL_ARG_DST, *dst_mem}); + + concat_p->execute(astream, args); + astream.wait(); + + output->set_mem_desc(dst_mem->get_desc().reshape(vectorize(output->dims()))); +} + +} // namespace phi + +PD_REGISTER_KERNEL(stack, 
OneDNN, ALL_LAYOUT, phi::StackKernel, float) {}
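For reference, the stack-to-concat mapping both stack kernels in this patch rely on: stacking N tensors along `axis` equals concatenating them along `axis` once each input gains a unit dimension there. A naive version for contiguous row-major data, splitting the shape into the product of dimensions before the axis (`outer`) and after it (`inner`), with a hypothetical function name:

```cpp
// Naive stack of equally-shaped tensors: dst has shape
// [outer, srcs.size(), inner] when each src is viewed as [outer, inner].
#include <cstdint>
#include <cstring>
#include <vector>

void NaiveStack(const std::vector<const float*>& srcs, int64_t outer,
                int64_t inner, float* dst) {
  const int64_t num = static_cast<int64_t>(srcs.size());
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t s = 0; s < num; ++s) {
      std::memcpy(dst + (o * num + s) * inner,
                  srcs[s] + o * inner,
                  inner * sizeof(float));
    }
  }
}
```

The oneDNN kernel reaches the same result without this extra copy pass by handing concat the reshaped source memory descriptors directly.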