diff --git a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc deleted file mode 100644 index 6a05585a37c6f348e061523a4735c5b6c7b33761..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using SelectedRows = phi::SelectedRows; - -template -class ShapeMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_var = ctx.InputVar("Input"); - framework::DDim in_dims; - if (in_var->IsType()) { - in_dims = in_var->Get().value().dims(); - } else { - in_dims = in_var->Get().dims(); - // Output of shape op is often fed as input to fill_constant ops - // and we need to rotate a shape otherwise Tensors of wrong shape may be - // allocated - if (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() == - framework::DataLayout::kNHWC && - in_dims.size() >= 3) { - auto rdims = phi::vectorize(in_dims); - std::rotate(rdims.begin() + 1, rdims.begin() + 2, rdims.end()); - in_dims = phi::make_ddim(rdims); - } - } - auto* out_t = ctx.Output("Out"); - out_t->Resize({in_dims.size()}); - auto out_data = out_t->mutable_data(platform::CPUPlace()); - for (int i = 0; i < in_dims.size(); ++i) { - out_data[i] = in_dims[i]; - } - - dnnl::memory::desc out_mem_desc( - phi::vectorize(out_t->dims()), - framework::ToMKLDNNDataType( - framework::TransToProtoVarType(out_t->dtype())), - platform::GetPlainMKLDNNFormat(out_t->dims().size())); - - out_t->set_mem_desc(out_mem_desc); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(shape, - MKLDNN, - paddle::platform::CPUPlace, - ops::ShapeMKLDNNKernel, - ops::ShapeMKLDNNKernel, - ops::ShapeMKLDNNKernel, - ops::ShapeMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc index 18c3e40280a2a4bfa49b7470aaf7562578e85b68..8066c2d86117503939fdddf0332b0606c4674e92 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc @@ -34,7 +34,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT); USE_OP_ITSELF(transpose); USE_OP_DEVICE_KERNEL(transpose, MKLDNN); USE_OP_ITSELF(shape); -USE_OP_DEVICE_KERNEL(shape, MKLDNN); +PD_DECLARE_KERNEL(shape, OneDNN, ALL_LAYOUT); USE_OP_ITSELF(crop); USE_OP_DEVICE_KERNEL(crop, CPU); diff --git a/paddle/phi/kernels/onednn/shape_kernel.cc b/paddle/phi/kernels/onednn/shape_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..b6fcd32f1c81a5de2d1ee57704ae69619bf789b3 --- /dev/null +++ b/paddle/phi/kernels/onednn/shape_kernel.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/shape_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ShapeKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + DDim x_dims = x.dims(); + + // Output of shape op is often fed as x to fill_constant ops + // and we need to rotate a shape otherwise Tensors of wrong shape may be + // allocated + if (OneDNNContext::tls().get_cur_paddle_data_layout() == DataLayout::kNHWC && + x_dims.size() >= 3) { + auto rdims = vectorize(x_dims); + std::rotate(rdims.begin() + 1, rdims.begin() + 2, rdims.end()); + x_dims = make_ddim(rdims); + } + + out->Resize({x_dims.size()}); + auto out_data = dev_ctx.template Alloc(out); + for (int i = 0; i < x_dims.size(); ++i) { + out_data[i] = x_dims[i]; + } + + dnnl::memory::desc out_mem_desc( + vectorize(out->dims()), + funcs::ToOneDNNDataType(out->dtype()), + funcs::GetPlainOneDNNFormat(out->dims().size())); + out->set_mem_desc(out_mem_desc); +} +} // namespace phi + +PD_REGISTER_KERNEL(shape, + OneDNN, + ALL_LAYOUT, + phi::ShapeKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {}