/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {

// Short alias for framework::Tensor inside paddle::operators. The kernels in
// this file spell out framework::Tensor explicitly.
typedef framework::Tensor Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Returns the number of 1-D blocks of kNumCUDAThreads threads needed to give
// one thread per element for N elements, capped at kNumMaximumNumBlocks.
// Because of the cap, kernels launched with this grid must loop over any
// elements beyond blocks * kNumCUDAThreads.
//
// Always returns at least 1. The CUDA runtime rejects a zero-sized grid as an
// invalid configuration, whereas a single block is a harmless no-op for a
// kernel that bounds-checks its index.
static inline int NumBlocks(const int N) {
  if (N <= 0) {
    return 1;
  }
  // Quotient plus remainder instead of (N + kNumCUDAThreads - 1): the latter
  // overflows signed int (undefined behaviour) when N is close to INT_MAX.
  const int blocks = N / kNumCUDAThreads + (N % kNumCUDAThreads != 0 ? 1 : 0);
  return std::min(blocks, kNumMaximumNumBlocks);
}

// Channel shuffle: within each of the nthreads / feature_map_size samples, the
// channel axis is viewed as a (group_row x group_column) matrix and transposed:
//
//   output[n][j * group_row + i][k] = input[n][i * group_column + j][k]
//
// for i < group_row, j < group_column, k < len.
//
// Parameters:
//   nthreads          total number of elements to move. Expected to equal
//                     n_samples * group_row * group_column * len. Values <= 0
//                     move nothing.
//   feature_map_size  elements per sample. Assumed to equal
//                     group_row * group_column * len.
//   output            destination buffer. Must not overlap `input`: threads
//                     read and write different, permuted positions.
//   input             source buffer.
//   group_row         rows of the per-sample channel matrix.
//   group_column      columns of the per-sample channel matrix.
//   len               elements per channel (spatial size).
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Any grid size is valid,
// because each thread strides by blockDim.x * gridDim.x until every element
// has been processed.
template <typename T>
__global__ void ShuffleChannel(const int nthreads,
                               const int feature_map_size,
                               T* output,
                               const T* input,
                               int group_row,
                               int group_column,
                               int len) {
  // size_t arithmetic so the running index and offsets cannot overflow int.
  const size_t total = nthreads > 0 ? static_cast<size_t>(nthreads) : 0;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t len_sz = static_cast<size_t>(len);
  const size_t rows = static_cast<size_t>(group_row);
  const size_t cols = static_cast<size_t>(group_column);
  const size_t fms = static_cast<size_t>(feature_map_size);

  for (size_t ii = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       ii < total;
       ii += stride) {
    // Decompose the CURRENT element index `ii`, not the thread's starting
    // index. Otherwise each thread would keep rewriting one element and every
    // element past the first grid-width would never be written.
    const size_t k = ii % len_sz;
    const size_t j = (ii / len_sz) % cols;
    const size_t i = (ii / len_sz / cols) % rows;
    const size_t n = ii / len_sz / cols / rows;
    output[n * fms + (j * rows + i) * len_sz + k] = input[ii];
  }
}
// GPU forward kernel for the shuffle_channel op.
//
// Reads input "X" (its dims()[0..3] are used as N, C, H, W) and attribute
// "group", and writes "Out". Within each sample, channel
// (i * group_column + j) is moved to (j * group_row + i), where
// group_row = group and group_column = C / group.
//
// The copy is enqueued on the context's CUDA stream; this method does not
// synchronize.
template <typename DeviceContext, typename T>
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* output = ctx.Output<framework::Tensor>("Out");
    const int group = ctx.Attr<int>("group");

    const auto input_dims = input->dims();
    const auto num = input_dims[0];
    const auto channel = input_dims[1];
    const auto height = input_dims[2];
    const auto width = input_dims[3];

    const auto sp_sz = height * width;
    const auto feature_map_size = channel * sp_sz;
    const int group_row = group;
    const int group_column = channel / group_row;
    // Product of N * C * H * W (same as numel() when C is divisible by group).
    // NOTE(review): stored as int to match the kernel's parameter, so tensors
    // with more than INT_MAX elements are not supported here.
    const int count =
        static_cast<int>(num * group_column * group_row * sp_sz);

    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // Empty tensor: nothing to move. Skip the launch instead of requesting a
    // zero-sized grid, which CUDA rejects as an invalid configuration.
    if (count == 0) {
      return;
    }
    const T* input_data = input->data<T>();

    // Size the grid from the element count the kernel actually processes.
    // NumBlocks caps the grid, and the kernel loops over any remainder.
    const int blocks = NumBlocks(count);
    const int threads = kNumCUDAThreads;
    ShuffleChannel<T>
        <<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
            count,
            static_cast<int>(feature_map_size),
            output_data,
            input_data,
            group_row,
            group_column,
            static_cast<int>(sp_sz));
  }
};

// GPU backward kernel for the shuffle_channel op.
//
// Reads the gradient of "Out" and writes the gradient of "X". The forward pass
// moves channel (i * group_column + j) to (j * group_row + i), with
// group_row = group and group_column = C / group. The gradient therefore must
// apply the INVERSE mapping:
//
//   dX[n][i * group_column + j][k] = dOut[n][j * group_row + i][k]
//
// That inverse is the same transpose with the two factors exchanged. Reusing
// the forward factors is only correct when group * group == C (or group is 1
// or C).
//
// The copy is enqueued on the context's CUDA stream; this method does not
// synchronize.
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* output_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* input_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    const int group = ctx.Attr<int>("group");

    // Shape is read from X@GRAD. This assumes its dims were already set
    // before Compute runs -- NOTE(review): presumably by the op's shape
    // inference; confirm.
    const auto& input_dims = input_grad->dims();
    const auto num = input_dims[0];
    const auto channel = input_dims[1];
    const auto height = input_dims[2];
    const auto width = input_dims[3];

    const auto sp_sz = height * width;
    const auto feature_map_size = channel * sp_sz;
    const int group_row = group;
    const int group_column = channel / group_row;
    // NOTE(review): int to match the kernel's parameter, so tensors with more
    // than INT_MAX elements are not supported here.
    const int count =
        static_cast<int>(num * group_column * group_row * sp_sz);

    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    // Empty tensor: nothing to move. Skip the launch instead of requesting a
    // zero-sized grid.
    if (count == 0) {
      return;
    }
    const T* output_grad_data = output_grad->data<T>();

    const int blocks = NumBlocks(count);
    const int threads = kNumCUDAThreads;
    // The kernel writes out[j * R + i] = in[i * C' + j] for the (R, C') it is
    // given. Passing the forward factors swapped (R = group_column,
    // C' = group_row) yields the inverse permutation.
    ShuffleChannel<T>
        <<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
            count,
            static_cast<int>(feature_map_size),
            input_grad_data,
            output_grad_data,
            /*group_row=*/group_column,
            /*group_column=*/group_row,
            static_cast<int>(sp_sz));
  }
};
}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
// Register the float and double GPU kernels for the forward op and for its
// gradient op.
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpCUDAKernel<phi::GPUContext, float>,
    ops::ShuffleChannelOpCUDAKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpCUDAKernel<phi::GPUContext, float>,
    ops::ShuffleChannelGradOpCUDAKernel<phi::GPUContext, double>);