未验证 提交 6bf85eaf 编写于 作者: Z zhangkaihuo 提交者: GitHub

Implement SparseConv3d kernel (#39784)

* sparse conv3d: gpu code
上级 71c69507
......@@ -145,6 +145,7 @@ class SparseCooTensor : public TensorBase,
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) override;
void set_dims(const DDim& dims) { this->dims_ = dims; }
private:
// save the indices of non zero elements in original dense tensor
......
set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function)
register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel")
......@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
namespace phi {
namespace sparse {
......@@ -55,7 +54,6 @@ void Conv3dKernel(const Context& dev_ctx,
// 1. product rulebook
DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
// DenseTensor rulebook = phi::Empty<int, Context>(dev_ctx);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
ProductRuleBook<T, Context>(dev_ctx,
......
此差异已折叠。
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
......@@ -151,6 +152,107 @@ void TestConv3dBase(const std::vector<int>& indices,
f_verify(grads[1].data<T>(), kernel_grad);
}
}
// test gpu
#if defined(PADDLE_WITH_CUDA)
phi::GPUContext dev_ctx_gpu;
dev_ctx_gpu.PartialInitWithoutAllocator();
dev_ctx_gpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream())
.get());
dev_ctx_gpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
DenseTensor d_indices_tensor = phi::Empty(
dev_ctx_gpu,
DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
dev_ctx_gpu.Alloc(&d_indices_tensor,
d_indices_tensor.dtype(),
sizeof(int) * d_indices_tensor.numel());
phi::Copy(
dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor);
DenseTensor d_features_tensor = phi::Empty(
dev_ctx_gpu,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
{non_zero_num, in_channels},
DataLayout::NHWC));
dev_ctx_gpu.Alloc(&d_features_tensor,
d_features_tensor.dtype(),
sizeof(T) * d_features_tensor.numel());
phi::Copy(
dev_ctx_gpu, features_tensor, phi::GPUPlace(), true, &d_features_tensor);
SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims);
DenseTensor d_kernel_tensor = phi::Empty(
dev_ctx_gpu,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
kernel_dims,
DataLayout::NHWC));
dev_ctx_gpu.Alloc(&d_kernel_tensor,
d_kernel_tensor.dtype(),
sizeof(T) * d_kernel_tensor.numel());
phi::Copy(
dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor);
DenseTensor d_rulebook = phi::Empty<int, phi::GPUContext>(dev_ctx_gpu);
SparseCooTensor d_out = sparse::Conv3d<T>(dev_ctx_gpu,
d_x_tensor,
d_kernel_tensor,
paddings,
dilations,
strides,
1,
&d_rulebook);
ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz());
for (int i = 0; i < correct_out_dims.size(); i++) {
ASSERT_EQ(correct_out_dims[i], d_out.dims()[i]);
}
DenseTensor h_indices_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW));
dev_ctx_cpu.Alloc(&h_indices_tensor,
h_indices_tensor.dtype(),
sizeof(int) * h_indices_tensor.numel());
phi::Copy(dev_ctx_gpu,
d_out.non_zero_indices(),
phi::CPUPlace(),
true,
&h_indices_tensor);
int cmp_indices2 = memcmp(correct_out_indices.data(),
h_indices_tensor.data<int>(),
correct_out_indices.size() * sizeof(int));
ASSERT_EQ(cmp_indices2, 0);
DenseTensor h_features_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
{d_out.nnz()},
d_out.layout()));
dev_ctx_cpu.Alloc(&h_features_tensor,
h_features_tensor.dtype(),
sizeof(T) * h_features_tensor.numel());
phi::Copy(dev_ctx_gpu,
d_out.non_zero_elements(),
phi::CPUPlace(),
true,
&h_features_tensor);
for (uint64_t i = 0; i < correct_out_features.size(); i++) {
float tmp = std::fabs(static_cast<float>(correct_out_features[i] -
h_features_tensor.data<T>()[i]));
ASSERT_LT(tmp, diff);
}
#endif
}
void TestConv3d(const std::vector<int>& indices,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册