Unverified commit ad41fce8 authored by Sławomir Siwek and committed by GitHub

[PHI] Migrate squeeze and squeeze_grad kernels (#48634)

* squeeze kernel

* squeeze fwd

* whitespace
Parent 4aad4dc5
@@ -21,7 +21,6 @@ enum class ReshapeKernelOpName {
reshape,
reshape2,
squeeze,
squeeze2,
flatten,
flatten2,
};
@@ -106,9 +105,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
case ReshapeKernelOpName::squeeze:
InferShapeSqueezeOp(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::squeeze2:
InferShapeSqueeze2Op(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::flatten:
InferShapeFlattenOp(ctx, x_dims, out_dims);
break;
@@ -172,16 +168,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
out_dims = GetOutputShape(axes, x_dims, true);
}
void InferShapeSqueeze2Op(const framework::ExecutionContext& ctx,
framework::DDim& x_dims, // NOLINT
framework::DDim& out_dims) const { // NOLINT
auto* out = ctx.Output<phi::DenseTensor>("Out");
auto* xshape = ctx.Output<phi::DenseTensor>("XShape");
auto xshape_dims = xshape->dims();
x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
out_dims = out->dims();
}
void InferShapeFlattenOp(const framework::ExecutionContext& ctx,
framework::DDim& x_dims, // NOLINT
framework::DDim& out_dims) const { // NOLINT
@@ -342,19 +328,16 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
InferShapeReshapeSqueezeGradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::reshape2:
InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
InferShapeReshape2Flatten2GradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::squeeze:
InferShapeReshapeSqueezeGradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::squeeze2:
InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::flatten:
InferShapeFlattenGradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::flatten2:
InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
InferShapeReshape2Flatten2GradOp(ctx, x_dims);
break;
default:
PADDLE_THROW(paddle::platform::errors::OutOfRange(
@@ -369,7 +352,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
dx_dims = dx->dims();
}
void InferShapeReshape2Squeeze2Flatten2GradOp(
void InferShapeReshape2Flatten2GradOp(
const framework::ExecutionContext& ctx,
framework::DDim& dx_dims) const { // NOLINT
auto xshape_dims = ctx.Input<phi::DenseTensor>("XShape")->dims();
@@ -401,22 +384,6 @@ REGISTER_OP_KERNEL(
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze>);
REGISTER_OP_KERNEL(
squeeze2,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze2>);
REGISTER_OP_KERNEL(
squeeze2_grad,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze2>);
REGISTER_OP_KERNEL(
reshape,
MKLDNN,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/squeeze_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void SqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& xshape,
const DenseTensor& dout,
const IntArray& axes,
DenseTensor* dx) {
auto dout_vec_dims = vectorize(dout.dims());
auto dout_type = funcs::ToOneDNNDataType(dout.dtype());
funcs::ReorderOneDNNHandler reorder_handler(
dout_vec_dims, dout.dtype(), dout_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout.mem_desc(), funcs::to_void_cast(dout.data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx,
funcs::GetPlainOneDNNFormat(dout_vec_dims.size()),
dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
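  // The reorder above copied dout into dx in a plain layout; now recover the
  // original input shape from XShape (its leading dim is a dummy 0).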
auto dx_dims = slice_ddim(xshape.dims(), 1, xshape.dims().size());
dx->Resize(dx_dims);
  dx->set_mem_desc(
      reorder_dst_memory_p->get_desc().reshape(vectorize(dx_dims)));
}
} // namespace phi
PD_REGISTER_KERNEL(squeeze_grad,
OneDNN,
ONEDNN,
phi::SqueezeGradKernel,
float,
phi::dtype::bfloat16) {}
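Both migrated kernels follow the same oneDNN pattern: the data is copied by a reorder into a plain (contiguous) layout, and the squeeze itself is only a reshape of the memory descriptor over the same buffer. The snippet below is a minimal standalone sketch of that pattern written directly against the dnnl C++ API; the [2, 1, 3] shape and the buffers are made up for illustration and the code is not part of this change.

#include <dnnl.hpp>
#include <vector>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // Hypothetical [2, 1, 3] f32 tensor in the plain "abc" layout.
  std::vector<float> src_buf = {0, 1, 2, 3, 4, 5};
  memory::desc md({2, 1, 3}, memory::data_type::f32, memory::format_tag::abc);
  memory src_mem(md, eng, src_buf.data());

  // Copy into a destination with the same plain layout (roughly what the
  // ReorderOneDNNHandler calls in the kernels boil down to).
  std::vector<float> dst_buf(src_buf.size());
  memory dst_mem(md, eng, dst_buf.data());
  reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
  strm.wait();

  // Squeezing is then just a descriptor reshape; the buffer is reused as-is.
  memory::desc squeezed_md = md.reshape({2, 3});
  (void)squeezed_md;
  return 0;
}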
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/squeeze_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
namespace phi {
template <typename T, typename Context>
void ExecuteSqueeze(const Context& dev_ctx,
const DenseTensor& x,
const DDim& x_dims,
const DDim& out_dims,
DenseTensor* out) {
auto x_vec_dims = vectorize(x_dims);
funcs::ReorderOneDNNHandler reorder_handler(
x_vec_dims,
x.dtype(),
funcs::ToOneDNNDataType(x.dtype()),
dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>()));
out->Resize(x_dims); // to match x numel, format is changed later
// reorder is done into a plain tag to allow usage with blocked formats
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, funcs::GetPlainOneDNNFormat(x_dims.size()), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
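  // The reorder has already copied the data; squeezing only rewrites the
  // logical dims and the oneDNN memory descriptor to the squeezed shape.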
out->Resize(out_dims);
out->set_mem_desc(
reorder_dst_memory_p->get_desc().reshape(vectorize(out_dims)));
}
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
auto x_dims = x.dims();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
}
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
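  // XShape stores a dummy leading 0 followed by x's original dims, and out
  // already carries the squeezed dims, so both shapes can be recovered here
  // without recomputing the squeeze axes.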
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
auto out_dims = out->dims();
ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(squeeze_with_xshape,
OneDNN,
ONEDNN,
phi::SqueezeWithXShapeKernel,
float,
phi::dtype::bfloat16) {}
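SqueezeKernel above derives the output shape with funcs::GetOutputSqueezeShape from funcs/unsqueeze.h. As a reference for the rule it applies, the standalone sketch below reproduces the usual squeeze semantics: drop the listed size-1 axes (negative axes wrap around), or drop every size-1 axis when no axes are given. It only illustrates the shape rule with made-up dimensions and is not the Paddle helper itself.

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Illustration of the squeeze-shape rule (not the Paddle helper): drop the
// listed size-1 axes, or every size-1 axis when `axes` is empty.
std::vector<int64_t> SqueezedShape(const std::vector<int64_t>& dims,
                                   const std::vector<int32_t>& axes) {
  std::set<int32_t> drop;
  for (int32_t a : axes) {
    if (a < 0) a += static_cast<int32_t>(dims.size());  // negative axes wrap
    if (dims[a] == 1) drop.insert(a);
  }
  std::vector<int64_t> out;
  for (size_t i = 0; i < dims.size(); ++i) {
    bool squeeze_all = axes.empty() && dims[i] == 1;
    if (squeeze_all || drop.count(static_cast<int32_t>(i)) > 0) continue;
    out.push_back(dims[i]);
  }
  return out;
}

int main() {
  // [2, 1, 3, 1] with axes {1} -> [2, 3, 1]
  assert((SqueezedShape({2, 1, 3, 1}, {1}) == std::vector<int64_t>{2, 3, 1}));
  // [2, 1, 3, 1] with no axes  -> [2, 3]
  assert((SqueezedShape({2, 1, 3, 1}, {}) == std::vector<int64_t>{2, 3}));
  return 0;
}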