/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <set>
#include <unordered_map>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"

namespace phi {
namespace sparse {

using Dims4D = phi::funcs::sparse::Dims4D;

// Builds the "rulebook" driving a sparse 3D convolution: one rulebook column
// per (input nonzero, kernel element) pair that contributes to a valid output
// location. Rulebook layout is {3, rulebook_len}:
//   row 0: flattened kernel offset, row 1: input nonzero position,
//   row 2: flattened output index.
// counter_per_kernel receives, per kernel offset, how many pairs use it.
// such as: kernel(3, 3, 3), kernel_size = 27
// counter_per_weight: (kernel_size)
// TODO(zhangkaihuo): optimize performance with multithreading
template <typename T, typename Context, typename IntT = int>
void ProductRuleBook(const Context& dev_ctx,
                     const SparseCooTensor& x,
                     const std::vector<int>& kernel_sizes,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations,
                     const std::vector<int>& strides,
                     const DDim& out_dims,
                     const bool subm,
                     DenseTensor* rulebook,
                     DenseTensor* counter_per_kernel) {
  const int64_t non_zero_num = x.nnz();
  const auto& non_zero_indices = x.non_zero_indices();
  // Indices are stored dimension-major: [batch..., z..., y..., x...].
  const IntT* indices_ptr = non_zero_indices.data<IntT>();
  int* counter_ptr = counter_per_kernel->data<int>();
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
  memset(counter_ptr, 0, kernel_size * sizeof(int));

  int rulebook_len = 0;
  // calc the rulebook_len
  const auto& x_dims = x.dims();
  // Dims4D takes (batch, w, h, d) — note the reversed spatial order.
  const Dims4D c_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
  const Dims4D c_kernel_dims(
      1, kernel_sizes[2], kernel_sizes[1], kernel_sizes[0]);
  const Dims4D c_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
  const Dims4D c_paddings(1, paddings[2], paddings[1], paddings[0]);
  const Dims4D c_strides(1, strides[2], strides[1], strides[0]);
  const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]);

  // For submanifold convolution, outputs are restricted to the set of input
  // locations; collect the flattened input indices to test membership below.
  std::set<IntT> hash_in;
  if (subm) {
    for (int i = 0; i < non_zero_num; i++) {
      IntT batch = indices_ptr[i];
      IntT in_z = indices_ptr[i + non_zero_num];
      IntT in_y = indices_ptr[i + 2 * non_zero_num];
      IntT in_x = indices_ptr[i + 3 * non_zero_num];
      IntT index = phi::funcs::sparse::PointToIndex<DDim>(
          batch, in_x, in_y, in_z, x_dims);
      hash_in.insert(index);
    }
  }

  // Two-pass helper: with nullptr it only counts (fills counter_ptr and
  // rulebook_len); with a real pointer it writes the rulebook entries. The
  // fill pass relies on rulebook_len computed by the counting pass for the
  // row strides, so the two passes must visit pairs in the same order.
  auto f_calc_rulebook = [&](IntT* rulebook_ptr) {
    int kernel_index = 0, rulebook_index = 0;
    for (int kz = 0; kz < kernel_sizes[0]; kz++) {
      for (int ky = 0; ky < kernel_sizes[1]; ky++) {
        for (int kx = 0; kx < kernel_sizes[2]; kx++) {
          ++kernel_index;
          for (int64_t i = 0; i < non_zero_num; i++) {
            IntT batch = indices_ptr[i];
            IntT in_z = indices_ptr[i + non_zero_num];
            IntT in_y = indices_ptr[i + 2 * non_zero_num];
            IntT in_x = indices_ptr[i + 3 * non_zero_num];
            // NOTE(review): integer division assumes Check() below rejects
            // positions not aligned to the stride — confirm in
            // phi::funcs::sparse::Check.
            IntT out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0];
            IntT out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1];
            IntT out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2];
            if (phi::funcs::sparse::Check(c_x_dims,
                                          c_kernel_dims,
                                          c_paddings,
                                          c_dilations,
                                          c_strides,
                                          in_x,
                                          in_y,
                                          in_z,
                                          kx,
                                          ky,
                                          kz)) {
              if (subm) {
                IntT out_index = phi::funcs::sparse::PointToIndex<DDim>(
                    batch, out_x, out_y, out_z, out_dims);
                // Skip outputs that do not coincide with an input location.
                if (hash_in.find(out_index) == hash_in.end()) {
                  continue;
                }
              }
              if (rulebook_ptr == nullptr) {
                // Counting pass: tally per-kernel-offset usage and total len.
                counter_ptr[kernel_index - 1] += 1;
                ++rulebook_len;
              } else {
                // Fill pass: rows are [kernel offset | in_i | out_index].
                rulebook_ptr[rulebook_index] = kernel_index - 1;
                rulebook_ptr[rulebook_index + rulebook_len] = i;  // in_i
                rulebook_ptr[rulebook_index + rulebook_len * 2] =
                    phi::funcs::sparse::PointToIndex<DDim>(
                        batch, out_x, out_y, out_z, out_dims);  // out_index
                ++rulebook_index;
              }
            }
          }
        }
      }
    }
  };

  f_calc_rulebook(nullptr);
  // alloc the rulebook
  *rulebook = phi::Empty(
      dev_ctx,
      DenseTensorMeta(paddle::experimental::CppTypeToDataType<IntT>::Type(),
                      {3, rulebook_len},
                      DataLayout::NCHW));
  IntT* rulebook_ptr = rulebook->data<IntT>();
  f_calc_rulebook(rulebook_ptr);
}

138
template <typename T, typename Context, typename IntT = int>
Z
zhangkaihuo 已提交
139 140 141 142 143 144 145
void UpdateRulebookAndOutIndex(const Context& dev_ctx,
                               const SparseCooTensor& x,
                               const int kernel_size,
                               const int out_channels,
                               const DDim& out_dims,
                               DenseTensor* rulebook,
                               SparseCooTensor* out) {
146
  std::set<IntT> out_indexs;
Z
zhangkaihuo 已提交
147
  int n = rulebook->dims()[1];
148
  IntT* rulebook_ptr = rulebook->data<IntT>();
Z
zhangkaihuo 已提交
149 150 151 152 153 154 155
  for (int i = 0; i < n; i++) {
    out_indexs.insert(rulebook_ptr[i + n * 2]);
  }

  int out_non_zero_num = out_indexs.size();
  const int64_t sparse_dim = 4;
  DenseTensorMeta indices_meta(
156 157 158
      paddle::experimental::CppTypeToDataType<IntT>::Type(),
      {sparse_dim, out_non_zero_num},
      DataLayout::NCHW);
159 160 161
  DenseTensorMeta values_meta(x.dtype(),
                              {out_non_zero_num, out_channels},
                              x.non_zero_elements().layout());
Z
zhangkaihuo 已提交
162 163
  phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
  phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
164
  IntT* out_indices_ptr = out_indices.data<IntT>();
Z
zhangkaihuo 已提交
165 166
  int i = 0;
  for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {
167 168
    const IntT index = *it;
    IntT batch, x, y, z;
169
    phi::funcs::sparse::IndexToPoint<DDim>(index, out_dims, &batch, &x, &y, &z);
Z
zhangkaihuo 已提交
170 171 172 173 174 175
    out_indices_ptr[i] = batch;
    out_indices_ptr[i + out_non_zero_num] = z;
    out_indices_ptr[i + out_non_zero_num * 2] = y;
    out_indices_ptr[i + out_non_zero_num * 3] = x;
  }
  for (i = 0; i < n; i++) {
176
    IntT out_index = rulebook_ptr[i + n * 2];
Z
zhangkaihuo 已提交
177 178 179 180 181 182 183
    rulebook_ptr[i + n * 2] =
        std::distance(out_indexs.begin(), out_indexs.find(out_index));
  }

  out->SetMember(out_indices, out_values, out_dims, true);
}

// Gathers `n` rows of `channels` contiguous elements each from `x` into
// `out`: row i of `out` is a copy of row indexs[i] of `x`.
template <typename T, typename IntT = int>
void Gather(
    const T* x, const IntT* indexs, const int n, const int channels, T* out) {
  const size_t row_bytes = channels * sizeof(T);
  T* dst = out;
  for (int i = 0; i < n; i++, dst += channels) {
    memcpy(dst, x + indexs[i] * channels, row_bytes);
  }
}

// Scatter-add: row i of `x` is accumulated (+=) into row indexs[i] of `out`.
// Duplicate indices accumulate; untouched rows of `out` are left unchanged.
template <typename T, typename IntT = int>
void Scatter(
    const T* x, const IntT* indexs, const int n, const int channels, T* out) {
  for (int i = 0; i < n; i++) {
    const T* src_row = x + i * channels;
    T* dst_row = out + indexs[i] * channels;
    for (int j = 0; j < channels; j++) {
      dst_row[j] += src_row[j];
    }
  }
}

}  // namespace sparse
}  // namespace phi