diff --git a/paddle/fluid/operators/flip_op.cu b/paddle/fluid/operators/flip_op.cu
index 26b3d11bc6c7b72720ad14c962bfc713779340ab..2391d4b907a6030c6d531347d9321f6af4a64428 100644
--- a/paddle/fluid/operators/flip_op.cu
+++ b/paddle/fluid/operators/flip_op.cu
@@ -24,24 +24,6 @@ namespace operators {
 using Tensor = framework::Tensor;
 using CUDADeviceContext = paddle::platform::CUDADeviceContext;
 
-template <typename T>
-__global__ void kernel_pointwise_flip_apply(const int N, const T* in_data,
-                                            T* out_data, int dim0, int stride0,
-                                            int dim1, int flip_dim) {
-  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
-       idx += gridDim.x * blockDim.x) {
-    int dst_offset = 0;
-    if (flip_dim == 0) {
-      // flip 1st dim
-      dst_offset = (dim0 - 1 - idx / stride0) * stride0 + idx % stride0;
-    } else {
-      // flip last dim
-      dst_offset = idx / stride0 * stride0 + (dim1 - 1 - idx % stride0);
-    }
-    out_data[dst_offset] = in_data[idx];
-  }
-}
-
 template <typename T>
 __global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
                                  int64_t* x_shape, int64_t* x_stride,
@@ -103,29 +85,6 @@ class FlipKernel
     std::vector<int64_t> x_dims_v = framework::vectorize(x_dims);
     std::vector<int64_t> x_stride_v = framework::vectorize(x_stride);
 
-    // wrap high-dims to 2-dims
-    if (flip_dims_size == 1 &&
-        (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
-      int dim0 = 1, dim1 = 1;
-      int stride0 = 1;
-      if (flip_dims[0] == 0) {
-        dim0 = x_dims_v[0];
-        stride0 = x_stride_v[0];
-        for (size_t i = 1; i < total_dims; ++i) {
-          dim1 *= x_dims_v[i];
-        }
-      } else {
-        dim1 = x_dims_v[total_dims - 1];
-        for (size_t i = 0; i < total_dims - 1; ++i) {
-          dim0 *= x_dims_v[i];
-        }
-        stride0 *= x_dims_v[total_dims - 1];
-      }
-      kernel_pointwise_flip_apply<
-          T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
-          N, in_data, out_data, dim0, stride0, dim1, flip_dims[0]);
-    }
-
     int bytes = total_dims * sizeof(int64_t);
     auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
     int64_t* x_strides_array_gpu =
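
Note: the hunks above delete the single-axis fast path that viewed the tensor as [dim0, dim1] and mirrored either the first or the last dimension; as the diff shows, that branch never returned early, so the generic stride-based flip_cuda_kernel ran afterwards in any case, and it alone now handles every flip. As a minimal, self-contained sketch (not Paddle code; the kernel name, shapes, and launch configuration below are illustrative assumptions), the destination-offset arithmetic of the removed kernel can be checked in isolation against a CPU reference:

    // flip_sketch.cu -- standalone check of the removed fast-path index math.
    // Build: nvcc -o flip_sketch flip_sketch.cu
    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    // Same arithmetic as the deleted kernel_pointwise_flip_apply: treat the
    // data as a [dim0, dim1] view with stride0 elements per dim0 step, then
    // mirror either the first dimension (flip_dim == 0) or the last one.
    __global__ void pointwise_flip(const int n, const float* in, float* out,
                                   int dim0, int stride0, int dim1,
                                   int flip_dim) {
      for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
           idx += gridDim.x * blockDim.x) {
        int dst = (flip_dim == 0)
                      ? (dim0 - 1 - idx / stride0) * stride0 + idx % stride0
                      : idx / stride0 * stride0 + (dim1 - 1 - idx % stride0);
        out[dst] = in[idx];
      }
    }

    int main() {
      const int dim0 = 3, dim1 = 4, n = dim0 * dim1;
      const int stride0 = dim1;  // contiguous row-major layout
      std::vector<float> h_in(n), h_out(n);
      for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>(i);

      float *d_in, *d_out;
      cudaMalloc(&d_in, n * sizeof(float));
      cudaMalloc(&d_out, n * sizeof(float));
      cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

      pointwise_flip<<<1, 128>>>(n, d_in, d_out, dim0, stride0, dim1,
                                 /*flip_dim=*/0);
      cudaMemcpy(h_out.data(), d_out, n * sizeof(float),
                 cudaMemcpyDeviceToHost);

      // CPU reference: output row r must equal input row (dim0 - 1 - r).
      bool ok = true;
      for (int r = 0; r < dim0; ++r)
        for (int c = 0; c < dim1; ++c)
          ok &= (h_out[r * dim1 + c] == h_in[(dim0 - 1 - r) * dim1 + c]);
      printf("flip along dim0 %s\n", ok ? "matches reference" : "MISMATCH");

      cudaFree(d_in);
      cudaFree(d_out);
      return ok ? 0 : 1;
    }

Since both the fast path and the retained generic kernel compute the same permutation, dropping the specialized branch removes a redundant kernel launch without changing results.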