未验证 提交 bc99a76c 编写于 作者: Z zhangkaihuo 提交者: GitHub

Add sparse conv3d kernel (#39879)

* fix incorrect dims settings

* sparse conv3d

* fix out dims

* test performance

* test large shape success

* opt scatter, double performance

* test float16

* remove profiling code

* remove pten

* opt code lines

* correct boundary judgment

* only cpu

* test ci

* test ci

* remove the including paddle/fluid header; extract the conmmon function

* opt code lines

* use DenseTensor::data() instead of mutable_data

* return rulebook for backward

* specify layout

* rename:conv -> sparse_conv3d
上级 61443a0e
......@@ -106,7 +106,7 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices,
const bool coalesced) {
this->non_zero_indices_ = non_zero_indices;
this->non_zero_elements_ = non_zero_elements;
this->dims_ = dims_;
this->dims_ = dims;
this->coalesced_ = coalesced;
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
namespace sparse {
struct Dims4D {
int dims[4];
Dims4D(const int batch, const int x, const int y, const int z) {
dims[0] = batch;
dims[1] = z;
dims[2] = y;
dims[3] = x;
}
HOSTDEVICE const int& operator[](int i) const { return dims[i]; }
};
// Judge whether the current position x is in (lower, upper)
inline HOSTDEVICE bool Check(const int& x,
const int& kx,
const int& pad,
const int& stride,
const int dilation,
const int kdim,
const int xdim) {
const int lower = x - dilation * kx + pad;
const int uper = x + (kdim - kx - 1) * dilation - pad;
return (lower >= 0 && lower % stride == 0 && uper < xdim);
}
// Check whether the current position(x, y, z) is legal:
// Judge the minimum and maximum values at each latitude
inline HOSTDEVICE bool Check(const Dims4D& dims,
const Dims4D& kernel_dims,
const Dims4D& paddings,
const Dims4D& dilations,
const Dims4D& strides,
const int x,
const int y,
const int z,
const int kx,
const int ky,
const int kz) {
bool x_valid = Check(
x, kx, paddings[3], strides[3], dilations[3], kernel_dims[3], dims[3]);
bool y_valid = Check(
y, ky, paddings[2], strides[2], dilations[2], kernel_dims[2], dims[2]);
bool z_valid = Check(
z, kz, paddings[1], strides[1], dilations[1], kernel_dims[1], dims[1]);
return (x_valid && y_valid && z_valid);
}
template <typename Dim>
inline HOSTDEVICE int PointToIndex(const int& batch,
const int& x,
const int& y,
const int& z,
const Dim& dims) {
return batch * dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] +
y * dims[3] + x;
}
template <typename Dim>
inline HOSTDEVICE void IndexToPoint(
const int index, const Dim& dims, int* batch, int* x, int* y, int* z) {
int n = index;
*x = n % dims[3];
n /= dims[3];
*y = n % dims[2];
n /= dims[2];
*z = n % dims[1];
n /= dims[1];
*batch = n;
}
inline void GetOutShape(const DDim& x_dims,
const DDim& kernel_dims,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
DDim* out_dims) {
PADDLE_ENFORCE_EQ(
x_dims.size(),
5,
phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)"));
PADDLE_ENFORCE_EQ(kernel_dims.size(),
5,
phi::errors::InvalidArgument(
"the shape of kernel should be (D, H, W, C, OC)"));
// infer out shape
(*out_dims)[0] = x_dims[0];
(*out_dims)[4] = kernel_dims[4];
for (int i = 1; i < 4; i++) {
(*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] -
dilations[i - 1] * (kernel_dims[i - 1] - 1) - 1) /
strides[i - 1] +
1;
}
}
template <typename T, typename Context>
void Conv3dKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
SparseCooTensor* out,
DenseTensor* rulebook);
template <typename T, typename Context>
SparseCooTensor Conv3d(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
DenseTensor* rulebook) {
DenseTensor indices = phi::Empty<T, Context>(dev_ctx);
DenseTensor values = phi::Empty<T, Context>(dev_ctx);
SparseCooTensor coo(indices, values, x.dims());
Conv3dKernel<T, Context>(
dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);
return coo;
}
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <set>
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
namespace sparse {
// such as: kernel(3, 3, 3), kernel_size = 27
// counter_per_weight: (kernel_size)
// TODO(zhangkaihuo): optimize performance with multithreading
template <typename T, typename Context>
void ProductRuleBook(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const DDim& out_dims,
DenseTensor* rulebook,
DenseTensor* counter_per_kernel) {
const auto& kernel_dims = kernel.dims();
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const int* indices_ptr = non_zero_indices.data<int>();
dev_ctx.Alloc(counter_per_kernel,
counter_per_kernel->dtype(),
sizeof(int) * counter_per_kernel->numel());
int* counter_ptr = counter_per_kernel->data<int>();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
memset(counter_ptr, 0, kernel_size * sizeof(int));
int rulebook_len = 0;
// calc the rulebook_len
const auto& x_dims = x.dims();
const Dims4D c_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
const Dims4D c_kernel_dims(1, kernel_dims[2], kernel_dims[1], kernel_dims[0]);
const Dims4D c_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
const Dims4D c_paddings(1, paddings[2], paddings[1], paddings[0]);
const Dims4D c_strides(1, strides[2], strides[1], strides[0]);
const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]);
auto f_calc_rulebook = [&](int* rulebook_ptr) {
int kernel_index = 0, rulebook_index = 0;
for (int kz = 0; kz < kernel_dims[0]; kz++) {
for (int ky = 0; ky < kernel_dims[1]; ky++) {
for (int kx = 0; kx < kernel_dims[2]; kx++) {
for (int64_t i = 0; i < non_zero_num; i++) {
int batch = indices_ptr[i];
int in_z = indices_ptr[i + non_zero_num];
int in_y = indices_ptr[i + 2 * non_zero_num];
int in_x = indices_ptr[i + 3 * non_zero_num];
int out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0];
int out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1];
int out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2];
if (Check(c_x_dims,
c_kernel_dims,
c_paddings,
c_dilations,
c_strides,
in_x,
in_y,
in_z,
kx,
ky,
kz)) {
if (rulebook_ptr == nullptr) {
counter_ptr[kernel_index] += 1;
++rulebook_len;
} else {
rulebook_ptr[rulebook_index] = kernel_index;
rulebook_ptr[rulebook_index + rulebook_len] = i; // in_i
rulebook_ptr[rulebook_index + rulebook_len * 2] =
PointToIndex<DDim>(
batch, out_x, out_y, out_z, out_dims); // out_index
++rulebook_index;
}
}
}
++kernel_index;
}
}
}
};
f_calc_rulebook(nullptr);
// alloc the rulebook
rulebook->ResizeAndAllocate({3, rulebook_len});
dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int));
int* rulebook_ptr = rulebook->data<int>();
f_calc_rulebook(rulebook_ptr);
}
template <typename T, typename Context>
void UpdateRulebookAndOutIndex(const Context& dev_ctx,
const SparseCooTensor& x,
const int kernel_size,
const int out_channels,
const DDim& out_dims,
DenseTensor* rulebook,
SparseCooTensor* out) {
std::set<int> out_indexs;
int n = rulebook->dims()[1];
int* rulebook_ptr = rulebook->data<int>();
for (int i = 0; i < n; i++) {
out_indexs.insert(rulebook_ptr[i + n * 2]);
}
int out_non_zero_num = out_indexs.size();
const int64_t sparse_dim = 4;
DenseTensorMeta indices_meta(
DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(
x.dtype(), {out_non_zero_num, out_channels}, x.layout());
phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
dev_ctx.Alloc(
&out_indices, out_indices.dtype(), out_indices.numel() * sizeof(int));
int* out_indices_ptr = out_indices.data<int>();
int i = 0;
for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {
const int index = *it;
int batch, x, y, z;
IndexToPoint<DDim>(index, out_dims, &batch, &x, &y, &z);
out_indices_ptr[i] = batch;
out_indices_ptr[i + out_non_zero_num] = z;
out_indices_ptr[i + out_non_zero_num * 2] = y;
out_indices_ptr[i + out_non_zero_num * 3] = x;
}
for (i = 0; i < n; i++) {
int out_index = rulebook_ptr[i + n * 2];
rulebook_ptr[i + n * 2] =
std::distance(out_indexs.begin(), out_indexs.find(out_index));
}
out->SetMember(out_indices, out_values, out_dims, true);
}
template <typename T>
void Gather(
const T* x, const int* indexs, const int n, const int channels, T* out) {
for (int i = 0; i < n; i++) {
int real_i = indexs[i];
memcpy(out + i * channels, x + real_i * channels, channels * sizeof(T));
}
}
template <typename T>
void Scatter(
const T* x, const int* indexs, const int n, const int channels, T* out) {
for (int i = 0; i < n; i++) {
int real_i = indexs[i];
for (int j = 0; j < channels; j++) {
out[real_i * channels + j] += x[i * channels + j];
}
}
}
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
namespace phi {
namespace sparse {
/**
* x: (N, D, H, W, C)
* kernel: (D, H, W, C, OC)
* out: (N, D, H, W, OC)
**/
template <typename T, typename Context>
void Conv3dKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
SparseCooTensor* out,
DenseTensor* rulebook) {
// update padding and dilation
// Currently, only support x.layout is NDHWC, groups = 1
// if x.layout != NDHWC then transpose(x), transpose(weight)
const auto& x_dims = x.dims();
const auto& kernel_dims = kernel.dims();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
DDim out_dims = {1, 1, 1, 1, 1};
GetOutShape(x_dims, kernel_dims, paddings, dilations, strides, &out_dims);
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook
DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
// DenseTensor rulebook = phi::Empty<int, Context>(dev_ctx);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
ProductRuleBook<T, Context>(dev_ctx,
x,
kernel,
paddings,
dilations,
strides,
out_dims,
rulebook,
&counter_per_kernel);
UpdateRulebookAndOutIndex<T>(
dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out);
int n = rulebook->dims()[1];
const int* counter_ptr = counter_per_kernel.data<int>();
// 2. gather
DenseTensorMeta in_features_meta(
x.dtype(), {n, in_channels}, DataLayout::NHWC);
DenseTensorMeta out_features_meta(
x.dtype(), {n, out_channels}, DataLayout::NHWC);
phi::DenseTensor in_features =
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::DenseTensor out_features =
phi::Empty(dev_ctx, std::move(out_features_meta));
dev_ctx.Alloc(&in_features, x.dtype(), sizeof(T) * in_features.numel());
dev_ctx.Alloc(&out_features, x.dtype(), sizeof(T) * out_features.numel());
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
Gather<T>(x.non_zero_elements().data<T>(),
rulebook->data<int>() + n,
n,
in_channels,
in_features_ptr);
// 3. call gemm for every werght
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
std::vector<int> offsets(kernel_size + 1);
int offset = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter_ptr[i];
}
offsets[kernel_size] = offset;
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter_ptr[i] <= 0) {
continue;
}
// call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = counter_ptr[i];
const int K = in_channels; // in_channels
const int N = out_channels; // out_channels
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1),
tmp_in_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_out_ptr);
}
// 4. scatter
dev_ctx.Alloc(out->mutable_non_zero_elements(),
out->mutable_non_zero_elements()->dtype(),
sizeof(T) * in_features.numel());
T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
Scatter<T>(out_features_ptr,
rulebook->data<int>() + n * 2,
n,
out_channels,
out_values_ptr);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(
sparse_conv3d, CPU, ALL_LAYOUT, phi::sparse::Conv3dKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
......@@ -86,19 +86,6 @@ __global__ void GetNonZeroElementsAndIndices(const T* dense_data,
}
}
template <typename Context>
void GetGpuLaunchConfig1D(const Context& dev_ctx,
const int64_t n,
int* grid_size,
int* block_size) {
const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
*block_size = (n >= MAX_BLOCK_DIM) ? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(n)));
*grid_size = n / *block_size;
*grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
}
template <typename T, typename Context>
void DenseToSparseCooKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -40,6 +40,19 @@ inline const DDim InferDenseDims(const DDim& x_dims,
return values_dims;
}
template <typename Context>
inline void GetGpuLaunchConfig1D(const Context& dev_ctx,
const int64_t n,
int* grid_size,
int* block_size) {
const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
*block_size = (n >= MAX_BLOCK_DIM) ? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(n)));
*grid_size = n / *block_size;
*grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
}
template <typename T, typename Context>
void DenseToSparseCooKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -13,6 +13,7 @@ cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_sparse_conv3d_dev_api SRCS test_sparse_conv3d_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_math_function SRCS test_math_function.cc DEPS math_function)
if(WITH_GPU)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace tests {
std::vector<int> flatten(const std::vector<std::vector<int>>& in) {
std::vector<int> out;
if (in.size() == 0) return out;
const int cols = in[0].size();
out.resize(in.size() * cols);
for (uint64_t i = 0; i < in.size(); i++) {
memcpy(&out[i * cols], in[i].data(), cols * sizeof(int));
}
return out;
}
template <typename T1, typename T2>
std::vector<T2> cast(const std::vector<T1>& in) {
std::vector<T2> out(in.size());
for (uint64_t i = 0; i < in.size(); i++) {
out[i] = static_cast<T2>(in[i]);
}
return out;
}
template <typename T>
void TestConv3dBase(const std::vector<int>& indices,
const std::vector<T>& features,
const DDim& x_dims,
const std::vector<T>& kernel,
const DDim& kernel_dims,
const std::vector<int>& correct_out_indices,
const std::vector<T>& correct_out_features,
const DDim& correct_out_dims,
const int non_zero_num,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const float diff = 1e-3) {
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.Init();
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
DenseTensor indices_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
dev_ctx_cpu.Alloc(&indices_tensor,
indices_tensor.dtype(),
sizeof(int) * indices_tensor.numel());
memcpy(
indices_tensor.data<int>(), indices.data(), indices.size() * sizeof(int));
DenseTensor features_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
{non_zero_num, in_channels},
DataLayout::NHWC));
dev_ctx_cpu.Alloc(&features_tensor,
features_tensor.dtype(),
features_tensor.numel() * sizeof(T));
memcpy(
features_tensor.data<T>(), features.data(), features.size() * sizeof(T));
SparseCooTensor x_tensor(indices_tensor, features_tensor, x_dims);
DenseTensor kernel_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
kernel_dims,
DataLayout::NHWC));
dev_ctx_cpu.Alloc(
&kernel_tensor, kernel_tensor.dtype(), kernel_tensor.numel() * sizeof(T));
memcpy(kernel_tensor.data<T>(), kernel.data(), kernel.size() * sizeof(T));
if (!std::is_same<T, phi::dtype::float16>::value) {
DenseTensor rulebook = phi::Empty<int, phi::CPUContext>(dev_ctx_cpu);
SparseCooTensor out = sparse::Conv3d<T>(dev_ctx_cpu,
x_tensor,
kernel_tensor,
paddings,
dilations,
strides,
1,
&rulebook);
ASSERT_EQ(correct_out_dims.size(), out.dims().size());
for (int i = 0; i < correct_out_dims.size(); i++) {
ASSERT_EQ(correct_out_dims[i], out.dims()[i]);
}
ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz());
int cmp_indices = memcmp(correct_out_indices.data(),
out.non_zero_indices().data<int>(),
correct_out_indices.size() * sizeof(int));
ASSERT_EQ(cmp_indices, 0);
for (uint64_t i = 0; i < correct_out_features.size(); i++) {
float tmp = std::fabs(static_cast<float>(
correct_out_features[i] - out.non_zero_elements().data<T>()[i]));
ASSERT_LT(tmp, diff);
}
}
}
void TestConv3d(const std::vector<int>& indices,
const std::vector<float>& features,
const DDim& x_dims,
const std::vector<float>& kernel,
const DDim& kernel_dims,
const std::vector<int>& correct_out_indices,
const std::vector<float>& correct_out_features,
const DDim& correct_out_dims,
const int non_zero_num,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations) {
// test float
TestConv3dBase<float>(indices,
features,
x_dims,
kernel,
kernel_dims,
correct_out_indices,
correct_out_features,
correct_out_dims,
non_zero_num,
paddings,
strides,
dilations);
// test double
TestConv3dBase<double>(indices,
cast<float, double>(features),
x_dims,
cast<float, double>(kernel),
kernel_dims,
correct_out_indices,
cast<float, double>(correct_out_features),
correct_out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
TEST(DEV_API, sparse_conv3d) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 4, 4, 4, in_channels};
DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 2, 2, 2, out_channels};
std::vector<int> paddings = {0, 0, 0};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 4;
std::vector<std::vector<int>> indices = {
{0, 0, 0, 0}, {0, 2, 0, 2}, {3, 2, 2, 3}, {3, 2, 3, 2}};
std::vector<int> indices_flatten = flatten(indices);
std::vector<float> features = {-0.2883, 0.0287, 0.2864, -0.0992};
// 3*3*3=27
std::vector<float> kernel = {
0.4721, 0.2292, 0.9751, 0.8616, 0.5784, 0.9178, 0.8727, 0.1659, 0.4455,
0.0189, 0.4646, 0.4472, 0.1991, 0.8968, 0.3717, 0.0051, 0.6963, 0.2690,
0.7473, 0.5403, 0.5391, 0.0796, 0.4734, 0.9097, 0.1712, 0.6237, 0.8837};
std::vector<std::vector<int>> out_indices = {{0, 0, 0, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 1, 1, 1, 1},
{0, 0, 1, 1, 0, 0, 1, 1},
{0, 1, 0, 1, 0, 1, 0, 1}};
std::vector<int> out_indices_flatten = flatten(out_indices);
std::vector<float> out_features = {
0.0254, 0.1455, -0.0615, 0.0862, 0.0077, 0.0200, -0.0160, -0.0433};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
TEST(DEV_API, sparse_conv3d_batch) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {2, 4, 4, 4, in_channels};
DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
DDim out_dims = {2, 2, 2, 2, out_channels};
std::vector<int> paddings = {0, 0, 0};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 8;
std::vector<std::vector<int>> indices = {{0, 0, 0, 0, 1, 1, 1, 1},
{0, 2, 0, 2, 0, 2, 0, 2},
{3, 2, 2, 3, 3, 2, 2, 3},
{3, 2, 3, 2, 3, 2, 3, 2}};
std::vector<int> indices_flatten = flatten(indices);
std::vector<float> features = {
-0.2883, 0.0287, 0.2864, -0.0992, -0.2883, 0.0287, 0.2864, -0.0992};
// 3*3*3=27
std::vector<float> kernel = {
0.4721, 0.2292, 0.9751, 0.8616, 0.5784, 0.9178, 0.8727, 0.1659, 0.4455,
0.0189, 0.4646, 0.4472, 0.1991, 0.8968, 0.3717, 0.0051, 0.6963, 0.2690,
0.7473, 0.5403, 0.5391, 0.0796, 0.4734, 0.9097, 0.1712, 0.6237, 0.8837};
std::vector<std::vector<int>> out_indices = {
{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1},
{0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
{0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1},
{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}};
std::vector<int> out_indices_flatten = flatten(out_indices);
std::vector<float> out_features = {0.0254,
0.1455,
-0.0615,
0.0862,
0.0077,
0.0200,
-0.0160,
-0.0433,
0.0254,
0.1455,
-0.0615,
0.0862,
0.0077,
0.0200,
-0.0160,
-0.0433};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
TEST(DEV_API, sparse_conv3d_stride) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 4, 4, 4, in_channels};
DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 1, 1, 1, out_channels};
std::vector<int> paddings = {0, 0, 0};
std::vector<int> strides = {2, 2, 2};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 3;
std::vector<std::vector<int>> indices = {
{0, 0, 0}, {0, 2, 0}, {3, 2, 2}, {3, 2, 3}};
std::vector<int> indices_flatten = flatten(indices);
std::vector<float> features = {-0.28833008, 0.02873230, 0.28637695};
// 3*3*3=27
std::vector<float> kernel = {
0.45043945, 0.47216797, 0.22924805, 0.97509766, 0.86181641, 0.57861328,
0.91796875, 0.87255859, 0.16589355, 0.44555664, 0.01889038, 0.46459961,
0.44726562, 0.19909668, 0.89697266, 0.37158203, 0.00513077, 0.69628906,
0.26904297, 0.74707031, 0.54003906, 0.5390625, 0.07958984, 0.47338867,
0.90966797, 0.17126465, 0.62353516};
std::vector<std::vector<int>> out_indices = {{0, 0, 0, 0}};
std::vector<int> out_indices_flatten = flatten(out_indices);
std::vector<float> out_features = {0.01791};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
TEST(DEV_API, sparse_conv3d_dilation) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 6, 6, 6, in_channels};
DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 2, 2, 2, out_channels};
std::vector<int> paddings = {0, 0, 0};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {2, 2, 2};
const int non_zero_num = 3;
std::vector<std::vector<int>> indices = {
{0, 0, 0}, {2, 3, 3}, {2, 3, 3}, {5, 2, 0}};
std::vector<int> indices_flatten = flatten(indices);
std::vector<float> features = {-0.78710938, -0.64746094, 0.98828125};
// 3*3*3=27
std::vector<float> kernel = {
0.20617676, 0.99365234, 0.16760254, 0.30639648, 0.41479492, 0.75732422,
0.65625, 0.48535156, 0.72167969, 0.56005859, 0.5, 0.3581543,
0.20324707, 0.88769531, 0.81298828, 0.58398438, 0.30810547, 0.12634277,
0.70507812, 0.38720703, 0.34814453, 0.02690125, 0.80273438, 0.90625,
0.2277832, 0.4362793, 0.44482422};
std::vector<std::vector<int>> out_indices = {{0, 0, 0, 1, 0, 1, 1, 0}};
std::vector<int> out_indices_flatten = flatten(out_indices);
std::vector<float> out_features = {-0.64014, -0.37402};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
TEST(DEV_API, sparse_conv3d_padding) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 3, 3, 3, in_channels};
DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 3, 3, 3, out_channels};
std::vector<int> paddings = {1, 1, 1};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 1;
std::vector<std::vector<int>> indices = {{0, 1, 0, 0}};
std::vector<int> indices_flatten = flatten(indices);
std::vector<float> features = {-0.79394531};
// 3*3*3=27
std::vector<float> kernel = {
0.34375, 0.22485352, 0.65820312, 0.75048828, 0.21411133, 0.17370605,
0.85546875, 0.53076172, 0.28833008, 0.71044922, 0.00659943, 0.45922852,
0.19372559, 0.64599609, 0.78808594, 0.49316406, 0.62646484, 0.40649414,
0.62744141, 0.5703125, 0.23144531, 0.50048828, 0.31835938, 0.90869141,
0.38208008, 0.60449219, 0.09075928};
std::vector<int> out_indices_flatten = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
std::vector<float> out_features = {-0.25269,
-0.39746,
-0.45288,
-0.49805,
-0.5127,
-0.15381,
-0.00524,
-0.56396,
-0.17004,
-0.5957,
-0.17847,
-0.27295};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
TEST(DEV_API, sparse_conv2d) {
const int in_channels = 1;
const int out_channels = 1;
DDim x_dims = {1, 1, 5, 5, in_channels};
DDim kernel_dims = {1, 3, 3, in_channels, out_channels};
DDim out_dims = {1, 1, 3, 3, out_channels};
std::vector<int> paddings = {0, 0, 0};
std::vector<int> strides = {1, 1, 1};
std::vector<int> dilations = {1, 1, 1};
const int non_zero_num = 3;
std::vector<int> indices_flatten = {0, 0, 0, 0, 0, 0, 0, 4, 0, 3, 2, 4};
std::vector<float> features = {-0.79394531, -0.3125, -0.55029297};
// 3*3*3=27
std::vector<float> kernel = {0.65820312,
0.75048828,
0.21411133,
0.17370605,
0.85546875,
0.53076172,
0.28833008,
0.71044922,
0.00659943};
std::vector<int> out_indices_flatten = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 2, 2, 2, 1, 2, 0, 1, 2};
std::vector<float> out_features = {
-0.17004, -0.71338, -0.00206, -0.22205, -0.09009};
TestConv3d(indices_flatten,
features,
x_dims,
kernel,
kernel_dims,
out_indices_flatten,
out_features,
out_dims,
non_zero_num,
paddings,
strides,
dilations);
}
} // namespace tests
} // namespace phi
......@@ -8,7 +8,7 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF NCHW KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册