diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc index 0c2b439b3e5102545dcd9f904c9e09bc4afbdd03..65a49dab27df2543f26c244b3224643aa094ff4f 100644 --- a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc @@ -21,7 +21,6 @@ enum class ReshapeKernelOpName { reshape, reshape2, squeeze, - squeeze2, flatten, flatten2, }; @@ -106,9 +105,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { case ReshapeKernelOpName::squeeze: InferShapeSqueezeOp(ctx, x_dims, out_dims); break; - case ReshapeKernelOpName::squeeze2: - InferShapeSqueeze2Op(ctx, x_dims, out_dims); - break; case ReshapeKernelOpName::flatten: InferShapeFlattenOp(ctx, x_dims, out_dims); break; @@ -172,16 +168,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { out_dims = GetOutputShape(axes, x_dims, true); } - void InferShapeSqueeze2Op(const framework::ExecutionContext& ctx, - framework::DDim& x_dims, // NOLINT - framework::DDim& out_dims) const { // NOLINT - auto* out = ctx.Output("Out"); - auto* xshape = ctx.Output("XShape"); - auto xshape_dims = xshape->dims(); - x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size()); - out_dims = out->dims(); - } - void InferShapeFlattenOp(const framework::ExecutionContext& ctx, framework::DDim& x_dims, // NOLINT framework::DDim& out_dims) const { // NOLINT @@ -342,19 +328,16 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { InferShapeReshapeSqueezeGradOp(ctx, x_dims); break; case ReshapeKernelOpName::reshape2: - InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims); + InferShapeReshape2Flatten2GradOp(ctx, x_dims); break; case ReshapeKernelOpName::squeeze: InferShapeReshapeSqueezeGradOp(ctx, x_dims); break; - case ReshapeKernelOpName::squeeze2: - InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims); - break; case ReshapeKernelOpName::flatten: InferShapeFlattenGradOp(ctx, x_dims); break; case ReshapeKernelOpName::flatten2: - InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims); + InferShapeReshape2Flatten2GradOp(ctx, x_dims); break; default: PADDLE_THROW(paddle::platform::errors::OutOfRange( @@ -369,7 +352,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { dx_dims = dx->dims(); } - void InferShapeReshape2Squeeze2Flatten2GradOp( + void InferShapeReshape2Flatten2GradOp( const framework::ExecutionContext& ctx, framework::DDim& dx_dims) const { // NOLINT auto xshape_dims = ctx.Input("XShape")->dims(); @@ -401,22 +384,6 @@ REGISTER_OP_KERNEL( ops::ReshapeGradMKLDNNKernel); -REGISTER_OP_KERNEL( - squeeze2, - MKLDNN, - paddle::platform::CPUPlace, - ops::ReshapeMKLDNNKernel, - ops::ReshapeMKLDNNKernel); - -REGISTER_OP_KERNEL( - squeeze2_grad, - MKLDNN, - paddle::platform::CPUPlace, - ops::ReshapeGradMKLDNNKernel, - ops::ReshapeGradMKLDNNKernel); - REGISTER_OP_KERNEL( reshape, MKLDNN, diff --git a/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc b/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..654acfe5700c3162904081b1371fd0c3bc635ae5 --- /dev/null +++ b/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/squeeze_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SqueezeGradKernel(const Context& dev_ctx, + const DenseTensor& xshape, + const DenseTensor& dout, + const IntArray& axes, + DenseTensor* dx) { + auto dout_vec_dims = vectorize(dout.dims()); + auto dout_type = funcs::ToOneDNNDataType(dout.dtype()); + + funcs::ReorderOneDNNHandler reorder_handler( + dout_vec_dims, dout.dtype(), dout_type, dev_ctx.GetEngine()); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + dout.mem_desc(), funcs::to_void_cast(dout.data())); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + dx, + funcs::GetPlainOneDNNFormat(dout_vec_dims.size()), + dev_ctx.GetPlace()); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + + auto& astream = OneDNNContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + auto dx_dims = slice_ddim(xshape.dims(), 1, xshape.dims().size()); + dx->Resize(dx_dims); + reorder_dst_memory_p->get_desc().reshape(vectorize(dx_dims)); +} + +} // namespace phi + +PD_REGISTER_KERNEL(squeeze_grad, + OneDNN, + ONEDNN, + phi::SqueezeGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/squeeze_kernel.cc b/paddle/phi/kernels/onednn/squeeze_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb7663f8e41b2d9177e20b28e9d467fcb5854465 --- /dev/null +++ b/paddle/phi/kernels/onednn/squeeze_kernel.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/squeeze_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/unsqueeze.h" + +namespace phi { + +template +void ExecuteSqueeze(const Context& dev_ctx, + const DenseTensor& x, + const DDim& x_dims, + const DDim& out_dims, + DenseTensor* out) { + auto x_vec_dims = vectorize(x_dims); + + funcs::ReorderOneDNNHandler reorder_handler( + x_vec_dims, + x.dtype(), + funcs::ToOneDNNDataType(x.dtype()), + dev_ctx.GetEngine()); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x.mem_desc(), funcs::to_void_cast(x.data())); + out->Resize(x_dims); // to match x numel, format is changed later + // reorder is done into a plain tag to allow usage with blocked formats + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + out, funcs::GetPlainOneDNNFormat(x_dims.size()), dev_ctx.GetPlace()); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + auto& astream = OneDNNContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + out->Resize(out_dims); + out->set_mem_desc( + reorder_dst_memory_p->get_desc().reshape(vectorize(out_dims))); +} + +template +void SqueezeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { + auto x_dims = x.dims(); + std::vector tmp(axes.GetData().begin(), axes.GetData().end()); + auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); + ExecuteSqueeze(dev_ctx, x, x_dims, out_dims, out); +} + +template +void SqueezeWithXShapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape) { + auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); + auto out_dims = out->dims(); + ExecuteSqueeze(dev_ctx, x, x_dims, out_dims, out); +} +} // namespace phi + +PD_REGISTER_KERNEL( + squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(squeeze_with_xshape, + OneDNN, + ONEDNN, + phi::SqueezeWithXShapeKernel, + float, + phi::dtype::bfloat16) {}