/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"

namespace phi {
namespace sparse {

template <typename T, typename IntT>
void Conv3dGPUKernel(const GPUContext& dev_ctx,
                     const SparseCooTensor& x,
                     const DenseTensor& kernel,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations,
                     const std::vector<int>& strides,
                     const int groups,
                     const bool subm,
                     SparseCooTensor* out,
                     DenseTensor* rulebook) {
  // Sparse 3D convolution on the GPU using a rulebook:
  //   1. build the rulebook (input/output index pairs per kernel position),
  //   2. gather the referenced input rows into a dense buffer,
  //   3. run one GEMM per kernel position,
  //   4. scatter the GEMM results into out's non-zero values.
  //
  // Currently only x.layout == NDHWC and groups == 1 are supported; `groups`
  // is not read by this function.
  // TODO: if x.layout != NDHWC then transpose(x) and transpose(kernel).
  const auto& x_dims = x.dims();
  const auto& kernel_dims = kernel.dims();
  // Number of spatial kernel positions: product of the first three kernel
  // dims (the last two are in_channels and out_channels, see below).
  const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
  DDim out_dims = {1, 1, 1, 1, 1};

  std::vector<int> kernel_sizes(kernel_dims.size());
  for (int i = 0; i < kernel_dims.size(); i++) {
    kernel_sizes[i] = kernel_dims[i];
  }

  std::vector<int> subm_paddings(paddings), subm_strides(strides);
  if (subm) {
    // The output shape of a submanifold conv is the same as the input shape:
    // reset padding = kernel_size / 2 and strides = 1.
    phi::funcs::sparse::ResetSubmKernelSizeAndStrides(
        kernel.dims(), &subm_paddings, &subm_strides);
  }

  phi::funcs::sparse::GetOutShape(
      x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims);
  const int in_channels = kernel_dims[3];
  const int out_channels = kernel_dims[4];
  // Host-side buffers filled by ProductRuleBook. Below, h_counter[i] is used
  // as the number of gathered rows for kernel position i, and offsets[i] as
  // the first of those rows.
  std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);

  // Second algorithm:
  // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
  // 1. product rulebook
  DenseTensorMeta counter_meta(
      DataType::INT32, {kernel_size}, DataLayout::NCHW);
  DenseTensorMeta offsets_meta(
      DataType::INT32, {kernel_size}, DataLayout::NCHW);
  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
  DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta));

  // Build a separate meta for each tensor: a meta must not be read again
  // after it has been moved into phi::Empty.
  DenseTensorMeta out_index_meta(DataType::INT32, {1}, DataLayout::NCHW);
  DenseTensorMeta unique_value_meta(DataType::INT32, {1}, DataLayout::NCHW);
  DenseTensor out_index = phi::Empty(dev_ctx, std::move(out_index_meta));
  DenseTensor unique_value = phi::Empty(dev_ctx, std::move(unique_value_meta));

  int n = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
                                               x,
                                               kernel_sizes,
                                               subm_paddings,
                                               dilations,
                                               subm_strides,
                                               out_dims,
                                               subm,
                                               rulebook,
                                               &counter_per_kernel,
                                               &offsets_per_kernel,
                                               &out_index,
                                               &unique_value,
                                               out,
                                               &h_counter,
                                               &offsets);

  // The rulebook is read in segments of length n:
  //   [n, 2n) indexes x's non-zero values (gather),
  //   [2n, 3n) indexes out's non-zero values (subm scatter).
  const IntT* rulebook_ptr = rulebook->data<IntT>();

  // 2. gather
  DenseTensorMeta in_features_meta(
      x.dtype(), {n, in_channels}, DataLayout::NCHW);
  DenseTensorMeta out_features_meta(
      x.dtype(), {n, out_channels}, DataLayout::NCHW);
  phi::DenseTensor in_features =
      phi::Empty(dev_ctx, std::move(in_features_meta));
  phi::DenseTensor out_features =
      phi::Empty(dev_ctx, std::move(out_features_meta));
  T* in_features_ptr = in_features.data<T>();
  T* out_features_ptr = out_features.data<T>();
  phi::funcs::SetConstant<GPUContext, T> set_zero;
  set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1);
  GatherKernel<T, IntT><<<config.block_per_grid.x,
                          config.thread_per_block.x,
                          0,
                          dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
                                              rulebook_ptr + n,
                                              in_features_ptr,
                                              n,
                                              in_channels);

  // 3. call gemm for every weight
  auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
  auto* out_values = out->mutable_non_zero_elements();
  T* out_values_ptr = out_values->data<T>();

  const T* kernel_ptr = kernel.data<T>();
  for (int i = 0; i < kernel_size; i++) {
    if (h_counter[i] <= 0) {
      continue;
    }

    // call gemm: (M, in_channels) * (in_channels, out_channels), where the M
    // rows are the gathered rows for kernel position i and the weight is
    // the i-th (K x N) slice of the kernel.
    const int M = h_counter[i];
    const int K = in_channels;
    const int N = out_channels;
    T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
    const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
    T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;

    blas.GEMM(CblasNoTrans,
              CblasNoTrans,
              M,
              N,
              K,
              static_cast<T>(1),
              tmp_in_ptr,
              tmp_kernel_ptr,
              static_cast<T>(0),
              tmp_out_ptr);
  }

  // 4. scatter
  if (subm) {
    // ScatterCUDAKernel is called with its final flag set to false
    // (presumably "accumulate" rather than "overwrite" — confirm in
    // scatter.cu.h), so out's values are zeroed first.
    set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
    config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1);
    phi::funcs::ScatterCUDAKernel<T, IntT><<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
        out_features_ptr,
        rulebook_ptr + 2 * n,
        out_values_ptr,
        n,
        out_channels,
        false);
  } else {
    // Non-subm: one work item per (output non-zero, out channel); the
    // unique_value / out_index tensors produced by ProductRuleBook drive the
    // reduction of GEMM rows into each output value.
    config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, out->nnz() * out_channels, 1);
    phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid.x,
                                           config.thread_per_block.x,
                                           0,
                                           dev_ctx.stream()>>>(
        out_features_ptr,
        unique_value.data<int>(),
        out_index.data<int>(),
        out->nnz(),
        n,
        out_channels,
        out_values_ptr);
  }
}
/**
 * Sparse 3D convolution (regular or submanifold when `subm` is true).
 *
 * x: (N, D, H, W, C)
 * kernel: (D, H, W, C, OC)
 * out: (N, D', H', W', OC); the spatial dims equal x's when `subm` is true.
 *
 * Dispatches on the integral dtype of x's non-zero indices and forwards every
 * argument unchanged to Conv3dGPUKernel<T, index type>.
 **/
template <typename T, typename Context>
void Conv3dKernel(const Context& dev_ctx,
                  const SparseCooTensor& x,
                  const DenseTensor& kernel,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const int groups,
                  const bool subm,
                  SparseCooTensor* out,
                  DenseTensor* rulebook) {
  // data_t is the index type selected by PD_VISIT_INTEGRAL_TYPES.
  PD_VISIT_INTEGRAL_TYPES(
      x.non_zero_indices().dtype(), "Conv3dGPUKernel", ([&] {
        Conv3dGPUKernel<T, data_t>(dev_ctx,
                                   x,
                                   kernel,
                                   paddings,
                                   dilations,
                                   strides,
                                   groups,
                                   subm,
                                   out,
                                   rulebook);
      }));
}

}  // namespace sparse
}  // namespace phi

// Register phi::sparse::Conv3dKernel as the GPU implementation of
// `sparse_conv3d` for float, double and float16 values. The body marks the
// first input (x) as having the SPARSE_COO data layout.
PD_REGISTER_KERNEL(sparse_conv3d,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::Conv3dKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}