From ad294a81fa340f439e75a41ba7c024a85d30b0e6 Mon Sep 17 00:00:00 2001
From: Yang
Date: Wed, 23 Feb 2022 19:12:41 +0800
Subject: [PATCH] [Phi] move flip op to phi kernel (#39822)

---
 paddle/fluid/operators/flip_op.cc     |  13 +--
 paddle/fluid/operators/flip_op.cu     | 129 -----------------------
 paddle/fluid/operators/flip_op.h      |  83 ---------------
 paddle/phi/kernels/cpu/flip_kernel.cc |  77 ++++++++++++++
 paddle/phi/kernels/flip_kernel.h      |  29 ++++++
 paddle/phi/kernels/gpu/flip_kernel.cu | 141 ++++++++++++++++++++++++++
 6 files changed, 250 insertions(+), 222 deletions(-)
 delete mode 100644 paddle/fluid/operators/flip_op.cu
 delete mode 100644 paddle/fluid/operators/flip_op.h
 create mode 100644 paddle/phi/kernels/cpu/flip_kernel.cc
 create mode 100644 paddle/phi/kernels/flip_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/flip_kernel.cu

diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc
index 3f6171b8a0..fc03ef0afa 100644
--- a/paddle/fluid/operators/flip_op.cc
+++ b/paddle/fluid/operators/flip_op.cc
@@ -12,12 +12,12 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/flip_op.h"
 #include <string>
 #include <unordered_map>
 #include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
 namespace operators {
@@ -29,6 +29,7 @@ class FlipOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
+  // TODO: move to phi kernel
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE_EQ(
         ctx->HasInput("X"), true,
@@ -150,14 +151,6 @@ namespace plat = paddle::platform;
 REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
                   ops::FlipOpGradMaker<paddle::framework::OpDesc>,
                   ops::FlipOpGradMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(
-    flip, ops::FlipKernel<plat::CPUDeviceContext, float>,
-    ops::FlipKernel<plat::CPUDeviceContext, double>,
-    ops::FlipKernel<plat::CPUDeviceContext, int32_t>,
-    ops::FlipKernel<plat::CPUDeviceContext, int64_t>,
-    ops::FlipKernel<plat::CPUDeviceContext, bool>,
-    ops::FlipKernel<plat::CPUDeviceContext, plat::complex<float>>,
-    ops::FlipKernel<plat::CPUDeviceContext, plat::complex<double>>);
 
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(flip)
diff --git a/paddle/fluid/operators/flip_op.cu b/paddle/fluid/operators/flip_op.cu
deleted file mode 100644
index b9f8b16214..0000000000
--- a/paddle/fluid/operators/flip_op.cu
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/flip_op.h"
-
-#include <vector>
-#include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/platform/complex.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using CUDADeviceContext = paddle::platform::CUDADeviceContext;
-
-template <typename T>
-__global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
-                                 int64_t* x_shape, int64_t* x_stride,
-                                 int* flip_dims, int flip_dims_size,
-                                 int total_dims) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx >= N) {
-    return;
-  }
-
-  int cur_indices = idx, rem = 0, dst_offset = 0;
-  for (int i = 0; i < total_dims; ++i) {
-    int64_t temp = cur_indices;
-    cur_indices = cur_indices / x_stride[i];
-    rem = temp - cur_indices * x_stride[i];
-    // flip the indices if it is in flip_dims
-    for (int j = 0; j < flip_dims_size; ++j) {
-      if (i == flip_dims[j]) {
-        cur_indices = x_shape[i] - 1 - cur_indices;
-      }
-    }
-    dst_offset += cur_indices * x_stride[i];
-    cur_indices = rem;
-  }
-  out_data[idx] = in_data[dst_offset];
-}
-
-template <typename T>
-class FlipKernel<platform::CUDADeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto gplace = ctx.GetPlace();
-    auto cplace = platform::CPUPlace();
-    auto& dev_ctx = ctx.template device_context<CUDADeviceContext>();
-
-    const Tensor* x = ctx.Input<Tensor>("X");
-    Tensor* out = ctx.Output<Tensor>("Out");
-    auto* in_data = x->data<T>();
-    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
-    auto flip_dims = ctx.template Attr<std::vector<int>>("axis");
-
-    const int flip_dims_size = static_cast<int>(flip_dims.size());
-    auto x_dims = x->dims();
-    const int total_dims = x_dims.size();
-    const int N = x->numel();
-
-    int block_size = 512;
-    dim3 dim_block(block_size);
-    dim3 dim_grid((N + block_size - 1) / block_size);
-
-    for (size_t i = 0; i < flip_dims.size(); ++i) {
-      if (flip_dims[i] < 0) {
-        flip_dims[i] += total_dims;
-      }
-    }
-
-    auto x_stride = phi::stride(x_dims);
-    std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
-    std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
-
-    int bytes = total_dims * sizeof(int64_t);
-    auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
-    int64_t* x_strides_array_gpu =
-        reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
-    memory::Copy(gplace, x_strides_array_gpu, cplace, x_stride_v.data(), bytes,
-                 dev_ctx.stream());
-
-    auto x_shape_array_tmp = memory::Alloc(dev_ctx, bytes);
-    int64_t* x_shape_array_gpu =
-        reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
-    memory::Copy(gplace, x_shape_array_gpu, cplace, x_dims_v.data(), bytes,
-                 dev_ctx.stream());
-
-    bytes = flip_dims_size * sizeof(int);
-    auto flip_dims_array_tmp = memory::Alloc(dev_ctx, bytes);
-    int* flip_dims_array_gpu =
-        reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
-    memory::Copy(gplace, flip_dims_array_gpu, cplace, flip_dims.data(), bytes,
-                 dev_ctx.stream());
-
-    flip_cuda_kernel<
-        T><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
-        N, in_data, out_data, x_shape_array_gpu, x_strides_array_gpu,
-        flip_dims_array_gpu, flip_dims_size, total_dims);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    flip, ops::FlipKernel<plat::CUDADeviceContext, float>,
-    ops::FlipKernel<plat::CUDADeviceContext, double>,
-    ops::FlipKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::FlipKernel<plat::CUDADeviceContext, int>,
-    ops::FlipKernel<plat::CUDADeviceContext, int64_t>,
-    ops::FlipKernel<plat::CUDADeviceContext, bool>,
-    ops::FlipKernel<plat::CUDADeviceContext, plat::complex<float>>,
-    ops::FlipKernel<plat::CUDADeviceContext, plat::complex<double>>);
diff --git a/paddle/fluid/operators/flip_op.h b/paddle/fluid/operators/flip_op.h
deleted file mode 100644
index 3c00df5f67..0000000000
--- a/paddle/fluid/operators/flip_op.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <bitset>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-constexpr size_t dim_bitset_size = 64;
-
-template <typename DeviceContext, typename T>
-class FlipKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override;
-};
-
-template <typename T>
-class FlipKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const Tensor* x = ctx.Input<Tensor>("X");
-    Tensor* out = ctx.Output<Tensor>("Out");
-    auto flip_dims = ctx.template Attr<std::vector<int>>("axis");
-
-    auto x_dims = x->dims();
-    const int total_dims = x_dims.size();
-    std::bitset<dim_bitset_size> dim_bitset;
-    for (size_t i = 0; i < flip_dims.size(); ++i) {
-      int dim = flip_dims[i];
-      if (flip_dims[i] < 0) {
-        dim += total_dims;
-      }
-      dim_bitset[dim] = true;
-    }
-    auto x_strides = phi::stride(x_dims);
-    auto numel = x->numel();
-    const T* x_data = x->data<T>();
-    T* out_data = out->mutable_data<T>(ctx.GetPlace());
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-    for (int64_t i = 0; i < numel; ++i) {
-      int64_t cur_indices = i;
-      int64_t rem = 0;
-      int64_t dst_offset = 0;
-
-      for (int d = 0; d < total_dims; ++d) {
-        int64_t temp = cur_indices;
-        cur_indices = cur_indices / x_strides[d];
-        rem = temp - cur_indices * x_strides[d];
-        dst_offset += dim_bitset[d]
-                          ? (x_dims[d] - 1 - cur_indices) * x_strides[d]
-                          : cur_indices * x_strides[d];
-        cur_indices = rem;
-      }
-      out_data[i] = x_data[dst_offset];
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/kernels/cpu/flip_kernel.cc b/paddle/phi/kernels/cpu/flip_kernel.cc
new file mode 100644
index 0000000000..fa1625d65b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/flip_kernel.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/flip_kernel.h"
+
+#include <bitset>
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+constexpr size_t dim_bitset_size = 64;
+
+template <typename T, typename Context>
+void FlipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int>& axis,
+                DenseTensor* out) {
+  auto x_dims = x.dims();
+  const int total_dims = x_dims.size();
+  std::bitset<dim_bitset_size> dim_bitset;
+  for (size_t i = 0; i < axis.size(); ++i) {
+    int dim = axis[i];
+    if (axis[i] < 0) {
+      dim += total_dims;
+    }
+    dim_bitset[dim] = true;
+  }
+  auto x_strides = phi::stride(x_dims);
+  auto numel = x.numel();
+  const T* x_data = x.data<T>();
+  T* out_data = dev_ctx.template Alloc<T>(out);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int64_t i = 0; i < numel; ++i) {
+    int64_t cur_indices = i;
+    int64_t rem = 0;
+    int64_t dst_offset = 0;
+
+    for (int d = 0; d < total_dims; ++d) {
+      int64_t temp = cur_indices;
+      cur_indices = cur_indices / x_strides[d];
+      rem = temp - cur_indices * x_strides[d];
+      dst_offset += dim_bitset[d] ? (x_dims[d] - 1 - cur_indices) * x_strides[d]
+                                  : cur_indices * x_strides[d];
+      cur_indices = rem;
+    }
+    out_data[i] = x_data[dst_offset];
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(flip,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FlipKernel,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/flip_kernel.h b/paddle/phi/kernels/flip_kernel.h
new file mode 100644
index 0000000000..4470486fec
--- /dev/null
+++ b/paddle/phi/kernels/flip_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FlipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int>& axis,
+                DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/flip_kernel.cu b/paddle/phi/kernels/gpu/flip_kernel.cu
new file mode 100644
index 0000000000..668d673bd3
--- /dev/null
+++ b/paddle/phi/kernels/gpu/flip_kernel.cu
@@ -0,0 +1,141 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/flip_kernel.h"
+
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void flip_cuda_kernel(const int N,
+                                 const T* in_data,
+                                 T* out_data,
+                                 int64_t* x_shape,
+                                 int64_t* x_stride,
+                                 int* flip_dims,
+                                 int flip_dims_size,
+                                 int total_dims) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= N) {
+    return;
+  }
+
+  int cur_indices = idx, rem = 0, dst_offset = 0;
+  for (int i = 0; i < total_dims; ++i) {
+    int64_t temp = cur_indices;
+    cur_indices = cur_indices / x_stride[i];
+    rem = temp - cur_indices * x_stride[i];
+    // flip the indices if it is in flip_dims
+    for (int j = 0; j < flip_dims_size; ++j) {
+      if (i == flip_dims[j]) {
+        cur_indices = x_shape[i] - 1 - cur_indices;
+      }
+    }
+    dst_offset += cur_indices * x_stride[i];
+    cur_indices = rem;
+  }
+  out_data[idx] = in_data[dst_offset];
+}
+
+template <typename T, typename Context>
+void FlipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int>& axis,
+                DenseTensor* out) {
+  const auto gplace = dev_ctx.GetPlace();
+  auto cplace = phi::CPUPlace();
+  std::vector<int> flip_dims = axis;
+
+  auto* in_data = x.data<T>();
+  auto* out_data = dev_ctx.template Alloc<T>(out);
+
+  const int flip_dims_size = static_cast<int>(flip_dims.size());
+  auto x_dims = x.dims();
+  const int total_dims = x_dims.size();
+  const int N = x.numel();
+
+  int block_size = 512;
+  dim3 dim_block(block_size);
+  dim3 dim_grid((N + block_size - 1) / block_size);
+
+  for (size_t i = 0; i < flip_dims.size(); ++i) {
+    if (flip_dims[i] < 0) {
+      flip_dims[i] += total_dims;
+    }
+  }
+
+  auto x_stride = phi::stride(x_dims);
+  std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
+  std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
+
+  int bytes = total_dims * sizeof(int64_t);
+  auto x_strides_array_tmp = paddle::memory::Alloc(dev_ctx, bytes);
+  int64_t* x_strides_array_gpu =
+      reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
+  paddle::memory::Copy(gplace,
+                       x_strides_array_gpu,
+                       cplace,
+                       x_stride_v.data(),
+                       bytes,
+                       dev_ctx.stream());
+
+  auto x_shape_array_tmp = paddle::memory::Alloc(dev_ctx, bytes);
+  int64_t* x_shape_array_gpu =
+      reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
+  paddle::memory::Copy(gplace,
+                       x_shape_array_gpu,
+                       cplace,
+                       x_dims_v.data(),
+                       bytes,
+                       dev_ctx.stream());
+
+  bytes = flip_dims_size * sizeof(int);
+  auto flip_dims_array_tmp = paddle::memory::Alloc(dev_ctx, bytes);
+  int* flip_dims_array_gpu = reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
+  paddle::memory::Copy(gplace,
+                       flip_dims_array_gpu,
+                       cplace,
+                       flip_dims.data(),
+                       bytes,
+                       dev_ctx.stream());
+
+  flip_cuda_kernel<T><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
+      N,
+      in_data,
+      out_data,
+      x_shape_array_gpu,
+      x_strides_array_gpu,
+      flip_dims_array_gpu,
+      flip_dims_size,
+      total_dims);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(flip,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FlipKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
--
GitLab
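
Note: the CPU loop and the CUDA kernel in this patch share one piece of index arithmetic: decompose each linear output index into per-dimension coordinates using the row-major strides, mirror the coordinate along every flipped axis, and re-linearize to obtain the source offset. A minimal standalone C++ sketch of that mapping follows; it is not part of the patch, and the helper names (row_major_strides, flip_source_offsets) are illustrative only.

// Standalone sketch of the offset remapping implemented by both FlipKernels.
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides for a shape, mirroring what phi::stride computes.
std::vector<int64_t> row_major_strides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int d = static_cast<int>(shape.size()) - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * shape[d + 1];
  }
  return strides;
}

// For each linear output index, peel off one coordinate per dimension using
// the strides, mirror it when that dimension is flipped, and re-linearize.
// This is the same per-element loop the patch runs on CPU and GPU.
std::vector<int64_t> flip_source_offsets(const std::vector<int64_t>& shape,
                                         const std::vector<int>& axes) {
  const auto strides = row_major_strides(shape);
  const int total_dims = static_cast<int>(shape.size());

  std::vector<bool> flipped(shape.size(), false);
  for (int a : axes) {
    flipped[a < 0 ? a + total_dims : a] = true;  // normalize negative axes
  }

  int64_t numel = 1;
  for (int64_t s : shape) numel *= s;

  std::vector<int64_t> offsets(numel);
  for (int64_t i = 0; i < numel; ++i) {
    int64_t rem = i, src = 0;
    for (int d = 0; d < total_dims; ++d) {
      const int64_t coord = rem / strides[d];
      rem -= coord * strides[d];
      src += (flipped[d] ? shape[d] - 1 - coord : coord) * strides[d];
    }
    offsets[i] = src;
  }
  return offsets;
}

int main() {
  // Flipping a 2x3 tensor [[0,1,2],[3,4,5]] along axis 1 gathers from
  // source offsets 2 1 0 5 4 3, i.e. the result is [[2,1,0],[5,4,3]].
  for (int64_t off : flip_source_offsets({2, 3}, {1})) std::cout << off << ' ';
  std::cout << '\n';
}

Because the mapping is computed independently per output element, the GPU port only has to ship the shape, stride, and axis arrays to device memory (the three paddle::memory::Copy calls above) and assign one thread per element.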