Unverified commit 0d78e491, authored by zhangkaihuo, committed by GitHub

Submanifold convolution (#40363)

Add submanifold convolution support to the sparse conv3d kernels (CPU and GPU, forward and backward), with tests and API updates.
Parent 17d8a5e0
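For context: a submanifold sparse convolution keeps the output active sites identical to the input active sites, so any candidate rulebook entry whose output coordinate is not already an input coordinate is discarded; that is what the hash_in set (CPU) and the set_difference pass (GPU) in this change implement. A minimal standalone sketch of that filtering rule, assuming a simple batch-major coordinate flattening; Coord, Flatten and FilterSubmOutputs are illustrative names, not Paddle APIs:

#include <set>
#include <vector>

struct Coord {
  int batch, z, y, x;
};

// Flatten a coordinate into a single integer key (batch-major, then z, y, x).
// D, H, W are the spatial extents of the input grid.
inline int Flatten(const Coord& c, int D, int H, int W) {
  return ((c.batch * D + c.z) * H + c.y) * W + c.x;
}

// Keep only the candidate output sites that coincide with an input site.
std::vector<Coord> FilterSubmOutputs(const std::vector<Coord>& inputs,
                                     const std::vector<Coord>& candidates,
                                     int D, int H, int W) {
  std::set<int> active;
  for (const Coord& c : inputs) active.insert(Flatten(c, D, H, W));
  std::vector<Coord> kept;
  for (const Coord& c : candidates) {
    if (active.count(Flatten(c, D, H, W)) > 0) kept.push_back(c);
  }
  return kept;
}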
......@@ -32,6 +32,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* x_grad,
DenseTensor* kernel_grad);
......@@ -44,7 +45,8 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups) {
const int groups,
const bool subm) {
DenseTensor x_grad =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
DenseTensor kernel_grad = phi::Empty<Context>(
......@@ -59,6 +61,7 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
dilations,
strides,
groups,
subm,
&x_grad,
&kernel_grad);
std::vector<DenseTensor> out(2);
......
......@@ -125,6 +125,7 @@ void Conv3dKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
SparseCooTensor* out,
DenseTensor* rulebook);
......@@ -136,14 +137,23 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* rulebook) {
DenseTensor indices = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor values =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
SparseCooTensor coo(indices, values, x.dims());
Conv3dKernel<T, Context>(
dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);
Conv3dKernel<T, Context>(dev_ctx,
x,
kernel,
paddings,
dilations,
strides,
groups,
subm,
&coo,
rulebook);
return coo;
}
......
......@@ -39,6 +39,7 @@ void ProductRuleBook(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const DDim& out_dims,
const bool subm,
DenseTensor* rulebook,
DenseTensor* counter_per_kernel) {
const auto& kernel_dims = kernel.dims();
......@@ -59,11 +60,24 @@ void ProductRuleBook(const Context& dev_ctx,
const Dims4D c_strides(1, strides[2], strides[1], strides[0]);
const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]);
std::set<int> hash_in;
if (subm) {
for (int i = 0; i < non_zero_num; i++) {
int batch = indices_ptr[i];
int in_z = indices_ptr[i + non_zero_num];
int in_y = indices_ptr[i + 2 * non_zero_num];
int in_x = indices_ptr[i + 3 * non_zero_num];
int index = PointToIndex<DDim>(batch, in_x, in_y, in_z, x_dims);
hash_in.insert(index);
}
}
auto f_calc_rulebook = [&](int* rulebook_ptr) {
int kernel_index = 0, rulebook_index = 0;
for (int kz = 0; kz < kernel_dims[0]; kz++) {
for (int ky = 0; ky < kernel_dims[1]; ky++) {
for (int kx = 0; kx < kernel_dims[2]; kx++) {
++kernel_index;
for (int64_t i = 0; i < non_zero_num; i++) {
int batch = indices_ptr[i];
int in_z = indices_ptr[i + non_zero_num];
......@@ -83,11 +97,19 @@ void ProductRuleBook(const Context& dev_ctx,
kx,
ky,
kz)) {
if (subm) {
int out_index =
PointToIndex<DDim>(batch, out_x, out_y, out_z, out_dims);
if (hash_in.find(out_index) == hash_in.end()) {
continue;
}
}
if (rulebook_ptr == nullptr) {
counter_ptr[kernel_index] += 1;
counter_ptr[kernel_index - 1] += 1;
++rulebook_len;
} else {
rulebook_ptr[rulebook_index] = kernel_index;
rulebook_ptr[rulebook_index] = kernel_index - 1;
rulebook_ptr[rulebook_index + rulebook_len] = i; // in_i
rulebook_ptr[rulebook_index + rulebook_len * 2] =
PointToIndex<DDim>(
......@@ -96,7 +118,6 @@ void ProductRuleBook(const Context& dev_ctx,
}
}
}
++kernel_index;
}
}
}
......
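The rulebook filled in above is a 3 x rulebook_len array: row 0 stores the kernel-offset index, row 1 the index of the input non-zero (in_i), and row 2 the flattened output coordinate, which is why the gather and scatter code later offsets the pointer by rulebook_len and rulebook_len * 2. A small illustrative reader for that layout (the function name is hypothetical, not part of the kernel):

#include <cstdio>
#include <vector>

// rulebook layout: 3 rows of length rulebook_len
//   row 0: kernel offset index
//   row 1: index of the input non-zero (in_i)
//   row 2: flattened output coordinate
void PrintRulebook(const std::vector<int>& rulebook, int rulebook_len) {
  const int* kernel_idx = rulebook.data();
  const int* in_idx = rulebook.data() + rulebook_len;
  const int* out_idx = rulebook.data() + 2 * rulebook_len;
  for (int j = 0; j < rulebook_len; ++j) {
    std::printf("entry %d: kernel offset %d, input nnz %d, output coord %d\n",
                j, kernel_idx[j], in_idx[j], out_idx[j]);
  }
}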
......@@ -38,6 +38,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
......@@ -70,32 +71,72 @@ void Conv3dGradKernel(const Context& dev_ctx,
T* d_kernel_ptr = kernel_grad->data<T>();
memset(d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel());
Gather<T>(x.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
rulebook_len,
in_channels,
in_features_ptr);
Gather<T>(out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
rulebook_len,
out_channels,
out_grad_features_ptr);
int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
x_grad->Resize(x.non_zero_elements().dims());
dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel());
T* x_grad_values_ptr = x_grad->data<T>();
memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel());
memset(d_x_features_ptr, 0, sizeof(T) * d_x_features.numel());
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0);
for (int i = 0; i < rulebook_len; i++) {
counter[rulebook_ptr[i]] += 1;
}
int offset = 0;
int offset = 0, max_count = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter[i];
if (i < half_kernel_size) {
max_count = std::max(max_count, counter[i]);
}
}
offsets[kernel_size] = offset;
if (subm) {
blas.GEMM(CblasTrans,
CblasNoTrans,
x.non_zero_elements().dims()[1],
out_grad.non_zero_elements().dims()[1],
x.non_zero_elements().dims()[0],
static_cast<T>(1),
x.non_zero_elements().data<T>(),
out_grad.non_zero_elements().data<T>(),
static_cast<T>(0),
d_kernel_ptr + half_kernel_size * in_channels * out_channels);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
T* x_grad_ptr = x_grad->data<T>();
blas.GEMM(CblasNoTrans,
CblasTrans,
out_grad.non_zero_elements().dims()[0],
in_channels,
out_grad.non_zero_elements().dims()[1],
static_cast<T>(1),
out_grad.non_zero_elements().data<T>(),
kernel.data<T>() + half_kernel_size * in_channels * out_channels,
static_cast<T>(0),
x_grad_ptr);
if (max_count == 0) {
return;
}
}
Gather<T>(x.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
rulebook_len,
in_channels,
in_features_ptr);
Gather<T>(out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
rulebook_len,
out_channels,
out_grad_features_ptr);
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) {
if (counter[i] <= 0 || (subm && i == half_kernel_size)) {
continue;
}
......@@ -136,10 +177,6 @@ void Conv3dGradKernel(const Context& dev_ctx,
}
// 4. scatter
x_grad->Resize(x.non_zero_elements().dims());
dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel());
T* x_grad_values_ptr = x_grad->data<T>();
memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel());
Scatter<T>(d_x_features_ptr,
rulebook.data<int>() + rulebook_len,
rulebook_len,
......
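In the subm branch above, the center kernel tap maps every input site to the output site at the same coordinate, so its contribution can be computed with two dense GEMMs over all non-zeros: d_kernel_center = x^T * out_grad and x_grad = out_grad * kernel_center^T; the per-offset loop then skips i == half_kernel_size. A naive-loop sketch of the same math, assuming row-major n x in_channels and n x out_channels buffers (the function name and signature are illustrative only):

#include <vector>

// d_kernel_center (in_c x out_c) = x^T * out_grad
// x_grad          (n x in_c)     = out_grad * k_center^T
// The remaining kernel taps are added to x_grad later by the
// gather/GEMM/scatter loop.
void SubmCenterTapBackward(const std::vector<float>& x,         // n x in_c
                           const std::vector<float>& out_grad,  // n x out_c
                           const std::vector<float>& k_center,  // in_c x out_c
                           int n, int in_c, int out_c,
                           std::vector<float>* d_kernel_center, // in_c x out_c
                           std::vector<float>* x_grad) {        // n x in_c
  for (int ic = 0; ic < in_c; ++ic) {
    for (int oc = 0; oc < out_c; ++oc) {
      float acc = 0.f;
      for (int i = 0; i < n; ++i) {
        acc += x[i * in_c + ic] * out_grad[i * out_c + oc];
      }
      (*d_kernel_center)[ic * out_c + oc] = acc;
    }
  }
  for (int i = 0; i < n; ++i) {
    for (int ic = 0; ic < in_c; ++ic) {
      float acc = 0.f;
      for (int oc = 0; oc < out_c; ++oc) {
        acc += out_grad[i * out_c + oc] * k_center[ic * out_c + oc];
      }
      (*x_grad)[i * in_c + ic] = acc;
    }
  }
}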
......@@ -35,6 +35,7 @@ void Conv3dKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
SparseCooTensor* out,
DenseTensor* rulebook) {
// update padding and dilation
......@@ -63,6 +64,7 @@ void Conv3dKernel(const Context& dev_ctx,
dilations,
strides,
out_dims,
subm,
rulebook,
&counter_per_kernel);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <set>
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/submanifold_convolution_kernel.h"
namespace phi {
namespace sparse {} // namespace sparse
} // namespace phi
......@@ -71,7 +71,8 @@ __global__ void ScatterKernel(const T* input,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out) {
T* out,
const bool subm = false) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
......@@ -82,6 +83,9 @@ __global__ void ScatterKernel(const T* input,
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
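// When subm is set the destination has already been partially initialized
// (e.g. x_grad is pre-filled by the center-tap GEMM in the backward kernel),
// so accumulation starts from the existing value instead of zero.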
if (subm) {
sum = out[indices_i * channels + channels_i];
}
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
......
......@@ -43,6 +43,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
......@@ -69,37 +70,18 @@ void Conv3dGradKernel(const Context& dev_ctx,
T* in_features_ptr = in_features.data<T>();
T* d_x_features_ptr = d_x_features.data<T>();
T* out_grad_features_ptr = out_grad_features.data<T>();
kernel_grad->Resize(kernel_dims);
dev_ctx.Alloc(
kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T));
kernel_grad->ResizeAndAllocate(kernel_dims);
T* d_kernel_ptr = kernel_grad->data<T>();
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, kernel_grad, static_cast<T>(0.0f));
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
in_features_ptr,
rulebook_len,
in_channels);
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * out_channels, 1);
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
out_grad_features_ptr,
rulebook_len,
out_channels);
int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
T* x_grad_values_ptr = x_grad->data<T>();
set_zero(dev_ctx, x_grad, static_cast<T>(0.0f));
set_zero(dev_ctx, &d_x_features, static_cast<T>(0.0f));
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0),
h_counter(rulebook_len, 0);
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
......@@ -117,16 +99,72 @@ void Conv3dGradKernel(const Context& dev_ctx,
for (int i = 0; i < rulebook_len; i++) {
counter[h_counter[i]] += 1;
}
int offset = 0;
int offset = 0, max_count = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter[i];
if (i < half_kernel_size) {
max_count = std::max(max_count, counter[i]);
}
}
offsets[kernel_size] = offset;
if (subm) {
blas.GEMM(CblasTrans,
CblasNoTrans,
x.non_zero_elements().dims()[1],
out_grad.non_zero_elements().dims()[1],
x.non_zero_elements().dims()[0],
static_cast<T>(1),
x.non_zero_elements().data<T>(),
out_grad.non_zero_elements().data<T>(),
static_cast<T>(0),
d_kernel_ptr + half_kernel_size * in_channels * out_channels);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
T* x_grad_ptr = x_grad->data<T>();
blas.GEMM(CblasNoTrans,
CblasTrans,
out_grad.non_zero_elements().dims()[0],
in_channels,
out_grad.non_zero_elements().dims()[1],
static_cast<T>(1),
out_grad.non_zero_elements().data<T>(),
kernel.data<T>() + half_kernel_size * in_channels * out_channels,
static_cast<T>(0),
x_grad_ptr);
if (max_count == 0) {
return;
}
}
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
in_features_ptr,
rulebook_len,
in_channels);
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * out_channels, 1);
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
out_grad_features_ptr,
rulebook_len,
out_channels);
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) {
if (counter[i] <= 0 || (subm && i == half_kernel_size)) {
continue;
}
......@@ -167,19 +205,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
}
// 4. scatter
x_grad->Resize(x.non_zero_elements().dims());
dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel());
T* x_grad_values_ptr = x_grad->data<T>();
DenseTensor out_index = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
DenseTensor unique_key = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
DenseTensor unique_value = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW);
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));
SortedAndUniqueIndex(dev_ctx,
rulebook_ptr + rulebook_len,
......@@ -200,7 +230,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
x.nnz(),
rulebook_len,
in_channels,
x_grad_values_ptr);
x_grad_values_ptr,
subm);
}
} // namespace sparse
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
......@@ -32,6 +33,34 @@ limitations under the License. */
namespace phi {
namespace sparse {
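/**
 * @brief: mark the rulebook entries whose positions are listed in indexs as
 * deleted by setting all three of their rows to -1, and subtract the
 * corresponding per-kernel-offset counts from counter_ptr; each block first
 * accumulates its decrements in a shared-memory cache of size kernel_size
 */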
__global__ void SetFlagAndUpdateCounterKernel(const int* indexs,
const int n,
const int rulebook_len,
const int kernel_size,
int* rulebook_ptr,
int* counter_ptr) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
extern __shared__ int cache_count[]; // kernel_size
for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
cache_count[i] = 0;
}
__syncthreads();
for (int i = tid; i < n; i += gridDim.x * blockDim.x) {
int index = indexs[i];
int kernel_index = rulebook_ptr[index];
rulebook_ptr[index + rulebook_len] = -1;
rulebook_ptr[index + 2 * rulebook_len] = -1;
rulebook_ptr[index] = -1;
atomicAdd(&cache_count[kernel_index], 1);
}
__syncthreads();
for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
atomicSub(&counter_ptr[i], cache_count[i]);
}
}
/**
* @brief: update the out index and indices
* unique_keys: save the index of the output feature list
......@@ -95,8 +124,10 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
const Dims4D paddings,
const Dims4D dilations,
const Dims4D strides,
const bool subm,
int* rulebook,
int* counter) {
int* counter,
int* in_indexs) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
extern __shared__ int counter_buf[]; // kernel_size
const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1];
......@@ -108,13 +139,16 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
int kernel_index = 0;
int batch = x_indices[i];
int in_z = x_indices[i + non_zero_num];
int in_y = x_indices[i + 2 * non_zero_num];
int in_x = x_indices[i + 3 * non_zero_num];
if (subm) {
in_indexs[i] = PointToIndex(batch, in_x, in_y, in_z, x_dims);
}
for (int kz = 0; kz < kernel_dims[1]; kz++) {
for (int ky = 0; ky < kernel_dims[2]; ky++) {
for (int kx = 0; kx < kernel_dims[3]; kx++) {
int batch = x_indices[i];
int in_z = x_indices[i + non_zero_num];
int in_y = x_indices[i + 2 * non_zero_num];
int in_x = x_indices[i + 3 * non_zero_num];
int in_i = -1, out_index = -1, kernel_i = -1;
if (Check(x_dims,
kernel_dims,
......@@ -182,6 +216,7 @@ int ProductRuleBook(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const DDim& out_dims,
const bool subm,
DenseTensor* rulebook,
DenseTensor* counter_per_kernel,
DenseTensor* offsets_per_kernel,
......@@ -195,13 +230,14 @@ int ProductRuleBook(const Context& dev_ctx,
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const int* indices_ptr = non_zero_indices.data<int>();
DenseTensor in_indexs = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW));
int* counter_ptr = counter_per_kernel->data<int>();
int* offsets_ptr = offsets_per_kernel->data<int>();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
const int rulebook_rows = 3;
const int rulebook_cols = kernel_size * non_zero_num;
rulebook->ResizeAndAllocate({rulebook_rows, rulebook_cols});
dev_ctx.Alloc(rulebook, rulebook->dtype(), sizeof(int) * rulebook->numel());
int* rulebook_ptr = rulebook->data<int>();
const auto x_dims = x.dims();
......@@ -229,8 +265,10 @@ int ProductRuleBook(const Context& dev_ctx,
d_paddings,
d_dilations,
d_strides,
subm,
rulebook_ptr,
counter_ptr);
counter_ptr,
in_indexs.data<int>());
// 2. remove -1
#ifdef PADDLE_WITH_HIP
......@@ -242,6 +280,144 @@ int ProductRuleBook(const Context& dev_ctx,
rulebook_ptr + rulebook_rows * rulebook_cols,
-1);
DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>(
rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1);
int rulebook_len = 0;
phi::backends::gpu::GpuMemcpyAsync(
&rulebook_len,
rulebook_ptr + 3 * kernel_size * non_zero_num - 1,
sizeof(int),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
rulebook_len /= 3;
dev_ctx.Wait();
if (subm) {
// At present, a hashtable is not used to map the input and output indexes.
// Instead, the intermediate output indexes are generated by a normal
// convolution, and the set difference between them and the input indexes
// gives the entries that must be removed to obtain the submanifold rulebook.
// get difference
int32_t* A_key_ptr = rulebook_ptr + 2 * rulebook_len;
int32_t* B_key_ptr = in_indexs.data<int>();
DenseTensor A_val = phi::Empty<Context>(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
DenseTensor B_val = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW));
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
dev_ctx, &A_val, kps::IdentityFunctor<int>());
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
dev_ctx, &B_val, kps::IdentityFunctor<int>());
DenseTensor key_result = phi::Empty<Context>(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len + 1}, DataLayout::NCHW));
DenseTensor val_result = phi::Empty<Context>(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
#ifdef PADDLE_WITH_HIP
thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
counter_ptr,
counter_ptr + kernel_size,
offsets_ptr);
std::vector<int> offsets(kernel_size, 0);
// TODO(zhangkaihuo): use a unified memcpy interface
phi::backends::gpu::GpuMemcpyAsync(offsets.data(),
offsets_ptr,
kernel_size * sizeof(int),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
thrust::pair<int*, int*> end;
// Because set_diff does not support duplicate data, set_diff is performed
// separately for each segment of data.
// TODO(zhangkaihuo): Using a hashtable here may give better performance;
// further tests are needed.
for (int i = 0; i < kernel_size; i++) {
int start = offsets[i];
int stop = i == kernel_size - 1 ? rulebook_len : offsets[i + 1];
int* key_result_start = (i == 0 ? key_result.data<int>() : end.first);
int* val_result_start = i == 0 ? val_result.data<int>() : end.second;
end =
#ifdef PADDLE_WITH_HIP
thrust::set_difference_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::set_difference_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
A_key_ptr + start,
A_key_ptr + stop,
B_key_ptr,
B_key_ptr + x.nnz(),
A_val.data<int>() + start,
B_val.data<int>(),
key_result_start,
val_result_start);
}
DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>(
key_result.data<int>(),
end.first,
key_result.data<int>() + rulebook_len);
int len = 0;
phi::backends::gpu::GpuMemcpyAsync(&len,
key_result.data<int>() + rulebook_len,
sizeof(int),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
// set the diff value = -1, and update counter
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len, 1);
SetFlagAndUpdateCounterKernel<<<config.block_per_grid.x,
config.thread_per_block,
kernel_size * sizeof(int),
dev_ctx.stream()>>>(val_result.data<int>(),
len,
rulebook_len,
kernel_size,
rulebook_ptr,
counter_ptr);
// remove -1
#ifdef PADDLE_WITH_HIP
int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
rulebook_ptr,
rulebook_ptr + 3 * rulebook_len,
-1);
DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>(
rulebook_ptr, last, key_result.data<int>() + rulebook_len);
phi::backends::gpu::GpuMemcpyAsync(&rulebook_len,
key_result.data<int>() + rulebook_len,
sizeof(int),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
rulebook_len /= 3;
}
#ifdef PADDLE_WITH_HIP
thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
......@@ -274,23 +450,14 @@ int ProductRuleBook(const Context& dev_ctx,
cudaMemcpyDeviceToHost,
dev_ctx.stream());
#endif
dev_ctx.Wait();
int rulebook_len =
(*h_counter)[kernel_size - 1] + (*h_offsets)[kernel_size - 1];
rulebook->Resize({rulebook_rows, rulebook_len});
// 3. sorted or merge the out index
out_index->ResizeAndAllocate({rulebook_len});
unique_value->ResizeAndAllocate({rulebook_len});
unique_key->ResizeAndAllocate({rulebook_len});
dev_ctx.Alloc(
out_index, out_index->dtype(), sizeof(int) * out_index->numel());
int* out_index_ptr = out_index->data<int>();
dev_ctx.Alloc(
unique_value, unique_value->dtype(), sizeof(int) * unique_value->numel());
int* unique_value_ptr = unique_value->data<int>();
dev_ctx.Alloc(
unique_key, unique_key->dtype(), sizeof(int) * unique_key->numel());
int* unique_key_ptr = unique_key->data<int>();
int* new_end = SortedAndUniqueIndex(dev_ctx,
......@@ -364,6 +531,7 @@ void Conv3dKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
SparseCooTensor* out,
DenseTensor* rulebook) {
// update padding and dilation
......@@ -389,20 +557,28 @@ void Conv3dKernel(const Context& dev_ctx,
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta));
DenseTensor out_index = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor unique_key = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor unique_value = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW);
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));
std::vector<int> subm_paddings(paddings), subm_strides(strides);
if (subm) {
auto kernel_dims = kernel.dims();
for (int i = 0; i < paddings.size(); i++) {
subm_paddings[i] = kernel_dims[i] / 2;
subm_strides[i] = 1;
}
}
int n = ProductRuleBook<T, Context>(dev_ctx,
x,
kernel,
paddings,
subm_paddings,
dilations,
strides,
subm_strides,
out_dims,
subm,
rulebook,
&counter_per_kernel,
&offsets_per_kernel,
......@@ -428,6 +604,8 @@ void Conv3dKernel(const Context& dev_ctx,
phi::Empty(dev_ctx, std::move(out_features_meta));
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1);
......
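For the submanifold path the forward kernel overrides paddings with kernel_dims[i] / 2 and strides with 1 (the subm_paddings / subm_strides loop above), so with dilation 1 the standard output-extent formula returns the input extent and output coordinates can line up one-to-one with input coordinates. A quick standalone check of that arithmetic (not Paddle code):

#include <cassert>

// Standard convolution output-extent formula.
int ConvOutSize(int in, int k, int pad, int stride, int dilation) {
  return (in + 2 * pad - dilation * (k - 1) - 1) / stride + 1;
}

int main() {
  // With pad = k / 2 (odd k), stride = 1 and dilation = 1 the spatial
  // extent is preserved, which is what the subm branch relies on.
  for (int in = 1; in <= 8; ++in) {
    for (int k = 1; k <= 7; k += 2) {
      assert(ConvOutSize(in, k, k / 2, /*stride=*/1, /*dilation=*/1) == in);
    }
  }
  return 0;
}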
......@@ -78,7 +78,7 @@ void TestConv3dBase(const std::vector<int>& indices,
if (!std::is_same<T, phi::dtype::float16>::value) {
auto outs = paddle::experimental::sparse::conv3d(
x, weight, paddings, dilations, strides, 1);
x, weight, paddings, dilations, strides, 1, false);
auto out = std::dynamic_pointer_cast<phi::SparseCooTensor>(
std::get<0>(outs).impl());
......
......@@ -64,7 +64,8 @@ void TestConv3dBase(const std::vector<int>& indices,
const float diff = 1e-3,
const bool backward = false,
const std::vector<T> features_grad = {},
const std::vector<T> kernel_grad = {}) {
const std::vector<T> kernel_grad = {},
const bool subm = false) {
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
......@@ -114,6 +115,7 @@ void TestConv3dBase(const std::vector<int>& indices,
dilations,
strides,
1,
subm,
&rulebook);
ASSERT_EQ(correct_out_dims.size(), out.dims().size());
......@@ -138,7 +140,8 @@ void TestConv3dBase(const std::vector<int>& indices,
paddings,
dilations,
strides,
1);
1,
subm);
f_verify(grads[0].data<T>(), features_grad);
f_verify(grads[1].data<T>(), kernel_grad);
}
......@@ -191,6 +194,7 @@ void TestConv3dBase(const std::vector<int>& indices,
dilations,
strides,
1,
subm,
&d_rulebook);
ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
......@@ -235,7 +239,8 @@ void TestConv3dBase(const std::vector<int>& indices,
paddings,
dilations,
strides,
1);
1,
subm);
DenseTensor h_features_grad = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(grads[0].dtype(), grads[0].dims(), grads[0].layout()));
......@@ -266,7 +271,8 @@ void TestConv3d(const std::vector<int>& indices,
const float diff = 1e-3,
const bool backward = false,
const std::vector<float> features_grad = {},
const std::vector<float> kernel_grad = {}) {
const std::vector<float> kernel_grad = {},
const bool subm = false) {
// test float
TestConv3dBase<float>(indices,
features,
......@@ -283,7 +289,8 @@ void TestConv3d(const std::vector<int>& indices,
diff,
backward,
features_grad,
kernel_grad);
kernel_grad,
subm);
// test double
TestConv3dBase<double>(indices,
cast<float, double>(features),
......@@ -300,7 +307,8 @@ void TestConv3d(const std::vector<int>& indices,
diff,
backward,
cast<float, double>(features_grad),
cast<float, double>(kernel_grad));
cast<float, double>(kernel_grad),
subm);
}
TEST(DEV_API, sparse_conv3d) {
......@@ -661,5 +669,101 @@ TEST(DEV_API, sparse_conv3d_backward) {
kernel_grad);
}
TEST(DEV_API, sparse_conv2d_subm) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 1, 4, 5, in_channels};
DDim kernel_dims = {1, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 1, 4, 5, out_channels};
std::vector<int> paddings = {0, 1, 1};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 4;
std::vector<int> indices_flatten = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 3, 2, 2, 3};
std::vector<float> features = {0.8854, 0.6505, -0.1999, 0.3583};
// 1*3*3=9
std::vector<float> kernel = {
0.9364, 0.9460, 0.6564, 0.7999, 0.2013, 0.3812, 0.5474, 0.1016, 0.3368};
std::vector<int> out_indices_flatten = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 3, 2, 2, 3};
std::vector<float> out_features = {0.1782, 0.2313, 0.7117, 0.5214};
std::vector<float> features_grad = {0.0359, 1.2080, 0.5838, 0.4541};
std::vector<float> kernel_grad = {
0.3391, 0.4630, 0.0000, -0.1042, 0.3528, 0.2550, 0.0000, -0.0462, 0.0829};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations,
1e-3,
true,
features_grad,
kernel_grad,
true);
}
TEST(DEV_API, sparse_conv3d_subm) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 4, 4, 5, in_channels};
DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 4, 4, 5, out_channels};
std::vector<int> paddings = {1, 1, 1};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 3;
std::vector<int> indices_flatten = {0, 0, 0, 1, 3, 3, 2, 0, 2, 0, 3, 1};
std::vector<float> features = {-0.9578, 0.1572, 0.1036};
// 3*3*3=27
std::vector<float> kernel = {
0.1367, 0.4534, 0.2138, 0.8264, 0.7534, 0.3270, 0.2880, 0.1562, 0.7770,
0.6902, 0.1981, 0.1369, 0.6582, 0.7582, 0.5640, 0.8894, 0.7350, 0.1845,
0.6892, 0.3654, 0.6076, 0.0326, 0.8412, 0.5289, 0.9824, 0.8235, 0.9802};
std::vector<int> out_indices_flatten = {0, 0, 0, 1, 3, 3, 2, 0, 2, 0, 3, 1};
std::vector<float> out_features = {-0.7262, 0.1192, 0.0785};
std::vector<float> features_grad = {-0.5506, 0.0904, 0.0595};
std::vector<float> kernel_grad = {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.7224, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations,
1e-3,
true,
features_grad,
kernel_grad,
true);
}
} // namespace tests
} // namespace phi
- api : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups)
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
kernel :
func : sparse_conv3d
......
- backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups)
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
kernel :
func : sparse_conv_grad
func : sparse_conv3d_grad