未验证 提交 ad294a81 编写于 作者: Y Yang 提交者: GitHub

[Phi] move flip op to phi kernel (#39822)

上级 64ed92bd
......@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flip_op.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
......@@ -29,6 +29,7 @@ class FlipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// TODO move to phi kernel
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
......@@ -150,14 +151,6 @@ namespace plat = paddle::platform;
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
flip, ops::FlipKernel<paddle::platform::CPUDeviceContext, float>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<double>>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flip_op.h"
#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template <typename T>
__global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
int64_t* x_shape, int64_t* x_stride,
int* flip_dims, int flip_dims_size,
int total_dims) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int cur_indices = idx, rem = 0, dst_offset = 0;
for (int i = 0; i < total_dims; ++i) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_stride[i];
rem = temp - cur_indices * x_stride[i];
// flip the indices if it is in flip_dims
for (int j = 0; j < flip_dims_size; ++j) {
if (i == flip_dims[j]) {
cur_indices = x_shape[i] - 1 - cur_indices;
}
}
dst_offset += cur_indices * x_stride[i];
cur_indices = rem;
}
out_data[idx] = in_data[dst_offset];
}
template <typename T>
class FlipKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto gplace = ctx.GetPlace();
auto cplace = platform::CPUPlace();
auto& dev_ctx = ctx.template device_context<CUDADeviceContext>();
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto flip_dims = ctx.template Attr<std::vector<int>>("axis");
const int flip_dims_size = static_cast<int>(flip_dims.size());
auto x_dims = x->dims();
const int total_dims = x_dims.size();
const int N = x->numel();
int block_size = 512;
dim3 dim_block(block_size);
dim3 dim_grid((N + block_size - 1) / block_size);
for (size_t i = 0; i < flip_dims.size(); ++i) {
if (flip_dims[i] < 0) {
flip_dims[i] += total_dims;
}
}
auto x_stride = phi::stride(x_dims);
std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
int bytes = total_dims * sizeof(int64_t);
auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
int64_t* x_strides_array_gpu =
reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_stride_v.data(), bytes,
dev_ctx.stream());
auto x_shape_array_tmp = memory::Alloc(dev_ctx, bytes);
int64_t* x_shape_array_gpu =
reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
memory::Copy(gplace, x_shape_array_gpu, cplace, x_dims_v.data(), bytes,
dev_ctx.stream());
bytes = flip_dims_size * sizeof(int);
auto flip_dims_array_tmp = memory::Alloc(dev_ctx, bytes);
int* flip_dims_array_gpu =
reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
memory::Copy(gplace, flip_dims_array_gpu, cplace, flip_dims.data(), bytes,
dev_ctx.stream());
flip_cuda_kernel<
T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
N, in_data, out_data, x_shape_array_gpu, x_strides_array_gpu,
flip_dims_array_gpu, flip_dims_size, total_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
flip, ops::FlipKernel<paddle::platform::CUDADeviceContext, float>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, double>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <bitset>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
constexpr size_t dim_bitset_size = 64;
template <typename DeviceContext, typename T>
class FlipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename T>
class FlipKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
auto flip_dims = ctx.template Attr<std::vector<int>>("axis");
auto x_dims = x->dims();
const int total_dims = x_dims.size();
std::bitset<dim_bitset_size> dim_bitset;
for (size_t i = 0; i < flip_dims.size(); ++i) {
int dim = flip_dims[i];
if (flip_dims[i] < 0) {
dim += total_dims;
}
dim_bitset[dim] = true;
}
auto x_strides = phi::stride(x_dims);
auto numel = x->numel();
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < numel; ++i) {
int64_t cur_indices = i;
int64_t rem = 0;
int64_t dst_offset = 0;
for (int d = 0; d < total_dims; ++d) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_strides[d];
rem = temp - cur_indices * x_strides[d];
dst_offset += dim_bitset[d]
? (x_dims[d] - 1 - cur_indices) * x_strides[d]
: cur_indices * x_strides[d];
cur_indices = rem;
}
out_data[i] = x_data[dst_offset];
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flip_kernel.h"
#include <bitset>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
constexpr size_t dim_bitset_size = 64;
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
auto x_dims = x.dims();
const int total_dims = x_dims.size();
std::bitset<dim_bitset_size> dim_bitset;
for (size_t i = 0; i < axis.size(); ++i) {
int dim = axis[i];
if (axis[i] < 0) {
dim += total_dims;
}
dim_bitset[dim] = true;
}
auto x_strides = phi::stride(x_dims);
auto numel = x.numel();
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < numel; ++i) {
int64_t cur_indices = i;
int64_t rem = 0;
int64_t dst_offset = 0;
for (int d = 0; d < total_dims; ++d) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_strides[d];
rem = temp - cur_indices * x_strides[d];
dst_offset += dim_bitset[d] ? (x_dims[d] - 1 - cur_indices) * x_strides[d]
: cur_indices * x_strides[d];
cur_indices = rem;
}
out_data[i] = x_data[dst_offset];
}
}
} // namespace phi
PD_REGISTER_KERNEL(flip,
CPU,
ALL_LAYOUT,
phi::FlipKernel,
float,
double,
int32_t,
int64_t,
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flip_kernel.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__global__ void flip_cuda_kernel(const int N,
const T* in_data,
T* out_data,
int64_t* x_shape,
int64_t* x_stride,
int* flip_dims,
int flip_dims_size,
int total_dims) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int cur_indices = idx, rem = 0, dst_offset = 0;
for (int i = 0; i < total_dims; ++i) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_stride[i];
rem = temp - cur_indices * x_stride[i];
// flip the indices if it is in flip_dims
for (int j = 0; j < flip_dims_size; ++j) {
if (i == flip_dims[j]) {
cur_indices = x_shape[i] - 1 - cur_indices;
}
}
dst_offset += cur_indices * x_stride[i];
cur_indices = rem;
}
out_data[idx] = in_data[dst_offset];
}
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
const auto gplace = dev_ctx.GetPlace();
auto cplace = phi::CPUPlace();
std::vector<int> flip_dims = axis;
auto* in_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
const int flip_dims_size = static_cast<int>(flip_dims.size());
auto x_dims = x.dims();
const int total_dims = x_dims.size();
const int N = x.numel();
int block_size = 512;
dim3 dim_block(block_size);
dim3 dim_grid((N + block_size - 1) / block_size);
for (size_t i = 0; i < flip_dims.size(); ++i) {
if (flip_dims[i] < 0) {
flip_dims[i] += total_dims;
}
}
auto x_stride = phi::stride(x_dims);
std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
int bytes = total_dims * sizeof(int64_t);
auto x_strides_array_tmp = paddle::memory::Alloc(dev_ctx, bytes);
int64_t* x_strides_array_gpu =
reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
paddle::memory::Copy(gplace,
x_strides_array_gpu,
cplace,
x_stride_v.data(),
bytes,
dev_ctx.stream());
auto x_shape_array_tmp = paddle::memory::Alloc(dev_ctx, bytes);
int64_t* x_shape_array_gpu =
reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
paddle::memory::Copy(gplace,
x_shape_array_gpu,
cplace,
x_dims_v.data(),
bytes,
dev_ctx.stream());
bytes = flip_dims_size * sizeof(int);
auto flip_dims_array_tmp = paddle::memory::Alloc(dev_ctx, bytes);
int* flip_dims_array_gpu = reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
paddle::memory::Copy(gplace,
flip_dims_array_gpu,
cplace,
flip_dims.data(),
bytes,
dev_ctx.stream());
flip_cuda_kernel<T><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
N,
in_data,
out_data,
x_shape_array_gpu,
x_strides_array_gpu,
flip_dims_array_gpu,
flip_dims_size,
total_dims);
}
} // namespace phi
PD_REGISTER_KERNEL(flip,
GPU,
ALL_LAYOUT,
phi::FlipKernel,
float,
double,
phi::dtype::float16,
int,
int64_t,
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册