Unverified commit 60b86b2f, authored by zhangkaihuo, committed by GitHub

Sparse Conv3d gpu backward (#40143)

Sparse conv3d backward(gpu)
Parent 3e9601ba
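For orientation, a minimal usage sketch of the functional wrapper touched by this change, assembled from the signatures and the test call later in this diff (it is not part of the commit; tensor names such as d_x_tensor, d_rulebook, d_kernel_tensor and d_out are taken from the test code):

    // Hedged sketch: compute the sparse conv3d gradients on GPU via the
    // functional wrapper, as exercised by the test in this diff.
    std::vector<phi::DenseTensor> grads =
        phi::sparse::Conv3dGrad<float>(dev_ctx_gpu,
                                       d_x_tensor,      // SparseCooTensor input x
                                       d_rulebook,      // rulebook built by the forward Conv3d
                                       d_kernel_tensor, // dense 5-D kernel
                                       d_out,           // SparseCooTensor out_grad
                                       paddings,
                                       dilations,
                                       strides,
                                       /*groups=*/1);
    // grads[0]: gradient w.r.t. the non-zero features of x
    // grads[1]: gradient w.r.t. the kernel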
......@@ -45,8 +45,10 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups) {
DenseTensor x_grad = phi::Empty<T, Context>(dev_ctx);
DenseTensor kernel_grad = phi::Empty<T, Context>(dev_ctx);
DenseTensor x_grad =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
DenseTensor kernel_grad = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout()));
// TODO(zhangkaihuo): call InferMeta func here
Conv3dGradKernel<T, Context>(dev_ctx,
x,
......
......@@ -20,18 +20,6 @@ limitations under the License. */
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx) {
phi::DenseTensor dense_out(
phi::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
{paddle::experimental::CppTypeToDataType<T>::Type(),
{-1},
DataLayout::NCHW});
return dense_out;
}
namespace sparse {
struct Dims4D {
......@@ -149,8 +137,10 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
const std::vector<int>& strides,
const int groups,
DenseTensor* rulebook) {
DenseTensor indices = phi::Empty<T, Context>(dev_ctx);
DenseTensor values = phi::Empty<T, Context>(dev_ctx);
DenseTensor indices = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor values =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
SparseCooTensor coo(indices, values, x.dims());
Conv3dKernel<T, Context>(
dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);
......
......@@ -45,9 +45,6 @@ void ProductRuleBook(const Context& dev_ctx,
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const int* indices_ptr = non_zero_indices.data<int>();
dev_ctx.Alloc(counter_per_kernel,
counter_per_kernel->dtype(),
sizeof(int) * counter_per_kernel->numel());
int* counter_ptr = counter_per_kernel->data<int>();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
memset(counter_ptr, 0, kernel_size * sizeof(int));
......@@ -138,8 +135,6 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
x.dtype(), {out_non_zero_num, out_channels}, x.layout());
phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
dev_ctx.Alloc(
&out_indices, out_indices.dtype(), out_indices.numel() * sizeof(int));
int* out_indices_ptr = out_indices.data<int>();
int i = 0;
for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
namespace phi {
......@@ -60,15 +61,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
phi::DenseTensor out_grad_features =
phi::Empty(dev_ctx, std::move(out_grad_features_meta));
dev_ctx.Alloc(
&in_features, in_features.dtype(), sizeof(T) * in_features.numel());
T* in_features_ptr = in_features.data<T>();
dev_ctx.Alloc(
&d_x_features, d_x_features.dtype(), sizeof(T) * d_x_features.numel());
T* d_x_features_ptr = d_x_features.data<T>();
dev_ctx.Alloc(&out_grad_features,
out_grad_features.dtype(),
sizeof(T) * out_grad_features.numel());
T* out_grad_features_ptr = out_grad_features.data<T>();
kernel_grad->Resize(kernel_dims);
dev_ctx.Alloc(
......@@ -156,12 +150,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sparse_conv_grad,
PD_REGISTER_KERNEL(sparse_conv3d_grad,
CPU,
ALL_LAYOUT,
phi::sparse::Conv3dGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
......@@ -81,8 +81,6 @@ void Conv3dKernel(const Context& dev_ctx,
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::DenseTensor out_features =
phi::Empty(dev_ctx, std::move(out_features_meta));
dev_ctx.Alloc(&in_features, x.dtype(), sizeof(T) * in_features.numel());
dev_ctx.Alloc(&out_features, x.dtype(), sizeof(T) * out_features.numel());
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
......@@ -128,9 +126,6 @@ void Conv3dKernel(const Context& dev_ctx,
}
// 4. scatter
dev_ctx.Alloc(out->mutable_non_zero_elements(),
out->mutable_non_zero_elements()->dtype(),
sizeof(T) * in_features.numel());
T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
Scatter<T>(out_features_ptr,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/execution_policy.h>
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
namespace phi {
namespace sparse {
// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
// this kernel with phi::GatherCUDAKernel;
// Vectorization can be used to improve read and write bandwidth
/**
* brief: gather data from params according to indices
* params: the inputs
* indices: the indices you want to gather
* output: the outputs
* index_size: the size of indices
* slice_size: slice size corresponding to each index, here is the channel size
**/
template <typename T, typename IndexT = int>
__global__ void GatherKernel(const T* params,
const IndexT* indices,
T* output,
size_t index_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
int64_t params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
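// Example (as used in Conv3dGradKernel below): gather rulebook_len rows of
// in_channels values from x.non_zero_elements() into in_features, using the
// input-index row of the rulebook as the gather indices:
//   GatherKernel<T, int><<<config.block_per_grid.x, config.thread_per_block.x,
//                          0, dev_ctx.stream()>>>(
//       x.non_zero_elements().data<T>(), rulebook_ptr + rulebook_len,
//       in_features_ptr, rulebook_len, in_channels);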
/**
* brief: scatter add
* input: the inputs
* unique_value: refer to UpdateIndexKernel notes
* out_index: the output feature index
* non_zero_num: the number of output features
* rulebook_len: the length of rulebook
* channels: the output channel size
* out: the outputs
**/
template <typename T>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
int channels_i = i - indices_i * channels;
int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
}
out[indices_i * channels + channels_i] = sum;
}
}
template <typename Context>
inline int* SortedAndUniqueIndex(const Context& dev_ctx,
const int* rulebook_ptr,
const int len,
DenseTensor* out_index,
DenseTensor* unique_key,
DenseTensor* unique_value) {
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
dev_ctx, out_index, kps::IdentityFunctor<int>());
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
dev_ctx, unique_value, kps::IdentityFunctor<int>());
phi::backends::gpu::GpuMemcpyAsync(unique_key->data<int>(),
rulebook_ptr,
sizeof(int) * len,
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToDevice,
#else
cudaMemcpyDeviceToDevice,
#endif
dev_ctx.stream());
// Compared with thrust::sort_by_key, thrust::merge_by_key may achieve higher
// performance, but thrust::merge_by_key is limited by data size
#ifdef PADDLE_WITH_HIP
thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
unique_key->data<int>(),
unique_key->data<int>() + len,
out_index->data<int>());
// 4. unique
thrust::pair<int*, int*> new_end =
#ifdef PADDLE_WITH_HIP
thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
unique_key->data<int>(),
unique_key->data<int>() + len,
unique_value->data<int>());
return new_end.first;
}
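// Note: SortedAndUniqueIndex returns the end of the compacted unique keys;
// ProductRuleBook (later in this diff) passes it to DistanceKernel to obtain
// the number of distinct output indices, while Conv3dGradKernel ignores the
// return value.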
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
namespace phi {
namespace sparse {
// rulebook[3, rulebook_len]:
//[
// [kernel_index],
// [in_i],
// [out_i],
//]
// x_grad = out_grad * transpose(kernel)
// kernel_grad = transpose(x) * out_grad
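// With rulebook_len columns, row k of the rulebook starts at
// rulebook_ptr + k * rulebook_len:
//   rulebook_ptr                    -> kernel_index of each (in, out) pair
//   rulebook_ptr + rulebook_len     -> index into x.non_zero_elements()
//   rulebook_ptr + 2 * rulebook_len -> index into out_grad.non_zero_elements()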
template <typename T, typename Context>
void Conv3dGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& kernel,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
DenseTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
const int* rulebook_ptr = rulebook.data<int>();
const int rulebook_len = rulebook.dims()[1];
DenseTensorMeta in_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
DenseTensorMeta d_x_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
DenseTensorMeta out_grad_features_meta(
x.dtype(), {rulebook_len, out_channels}, DataLayout::NCHW);
phi::DenseTensor in_features =
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::DenseTensor d_x_features =
phi::Empty(dev_ctx, std::move(d_x_features_meta));
phi::DenseTensor out_grad_features =
phi::Empty(dev_ctx, std::move(out_grad_features_meta));
T* in_features_ptr = in_features.data<T>();
T* d_x_features_ptr = d_x_features.data<T>();
T* out_grad_features_ptr = out_grad_features.data<T>();
kernel_grad->Resize(kernel_dims);
dev_ctx.Alloc(
kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T));
T* d_kernel_ptr = kernel_grad->data<T>();
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, kernel_grad, static_cast<T>(0.0f));
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
in_features_ptr,
rulebook_len,
in_channels);
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * out_channels, 1);
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
out_grad_features_ptr,
rulebook_len,
out_channels);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0),
h_counter(rulebook_len, 0);
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
rulebook_ptr,
rulebook_len * sizeof(int),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
for (int i = 0; i < rulebook_len; i++) {
counter[h_counter[i]] += 1;
}
int offset = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter[i];
}
offsets[kernel_size] = offset;
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) {
continue;
}
const int M = counter[i];
const int K = in_channels;
const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * out_channels;
T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas.GEMM(CblasTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1),
tmp_in_ptr,
tmp_out_grad_ptr,
static_cast<T>(0),
tmp_d_kernel_ptr);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
blas.GEMM(CblasNoTrans,
CblasTrans,
M,
K,
N,
static_cast<T>(1),
tmp_out_grad_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_d_x_ptr);
}
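// d_x_features now holds one partial input-gradient row per rulebook entry;
// entries that map to the same input non-zero are accumulated into x_grad by
// the scatter-add below.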
// 4. scatter
x_grad->Resize(x.non_zero_elements().dims());
dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel());
T* x_grad_values_ptr = x_grad->data<T>();
DenseTensor out_index = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
DenseTensor unique_key = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
DenseTensor unique_value = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
SortedAndUniqueIndex(dev_ctx,
rulebook_ptr + rulebook_len,
rulebook_len,
&out_index,
&unique_key,
&unique_value);
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
ScatterKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(d_x_features_ptr,
unique_value.data<int>(),
out_index.data<int>(),
x.nnz(),
rulebook_len,
in_channels,
x_grad_values_ptr);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sparse_conv3d_grad,
GPU,
ALL_LAYOUT,
phi::sparse::Conv3dGradKernel,
float,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <thrust/sort.h>
#include <thrust/unique.h>
#include "glog/logging.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
......@@ -28,19 +27,11 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
namespace phi {
namespace sparse {
// TODO(zhangkaihuo) replace this kernel with KP::InitWithDataIndex
__global__ void InitByIndexKernel(const int n, int* out1, int* out2) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < n; i += gridDim.x * blockDim.x) {
out1[i] = i;
out2[i] = i;
}
}
/**
* @brief: update the out index and indices
* unique_keys: save the index of the output feature list
......@@ -124,7 +115,7 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
int in_z = x_indices[i + non_zero_num];
int in_y = x_indices[i + 2 * non_zero_num];
int in_x = x_indices[i + 3 * non_zero_num];
int in_i = -1, out_index = -1;
int in_i = -1, out_index = -1, kernel_i = -1;
if (Check(x_dims,
kernel_dims,
paddings,
......@@ -143,9 +134,11 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
out_index =
PointToIndex<Dims4D>(batch, out_x, out_y, out_z, out_dims);
atomicAdd(&counter_buf[kernel_index], 1);
kernel_i = kernel_index;
}
rulebook[kernel_index * non_zero_num + i] = in_i;
rulebook[kernel_index * non_zero_num + offset + i] = out_index;
rulebook[kernel_index * non_zero_num + i] = kernel_i;
rulebook[kernel_index * non_zero_num + offset + i] = in_i;
rulebook[kernel_index * non_zero_num + offset * 2 + i] = out_index;
++kernel_index;
}
}
......@@ -157,68 +150,6 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
}
}
// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
// this kernel with phi::GatherCUDAKernel;
// Vectorization can be used to improve read and write bandwidth
/**
* brief: gather data from params according to indices
* params: the inputs
* indices: the indices you want to gather
* output: the outputs
* index_size: the size of indices
* slice_size: slice size corresponding to each index, here is the channel size
**/
template <typename T, typename IndexT = int>
__global__ void GatherKernel(const T* params,
const IndexT* indices,
T* output,
size_t index_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
int64_t params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
/**
* brief: scatter add
* input: the inputs
* unique_value: refer to UpdateIndexKernel notes
* out_index: the output feature index
* non_zero_num: the number of output features
* rulebook_len: the length of rulebook
* channels: the output channel size
* out: the outputs
**/
template <typename T>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
int channels_i = i - indices_i * channels;
int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
}
out[indices_i * channels + channels_i] = sum;
}
}
// brief: calculate the distance between start and end
__global__ void DistanceKernel(const int* start,
const int* end,
......@@ -264,16 +195,12 @@ int ProductRuleBook(const Context& dev_ctx,
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const int* indices_ptr = non_zero_indices.data<int>();
dev_ctx.Alloc(counter_per_kernel,
counter_per_kernel->dtype(),
sizeof(int) * counter_per_kernel->numel());
int* counter_ptr = counter_per_kernel->data<int>();
dev_ctx.Alloc(offsets_per_kernel,
offsets_per_kernel->dtype(),
sizeof(int) * offsets_per_kernel->numel());
int* offsets_ptr = offsets_per_kernel->data<int>();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
rulebook->ResizeAndAllocate({2, kernel_size * non_zero_num});
const int rulebook_rows = 3;
const int rulebook_cols = kernel_size * non_zero_num;
rulebook->ResizeAndAllocate({rulebook_rows, rulebook_cols});
dev_ctx.Alloc(rulebook, rulebook->dtype(), sizeof(int) * rulebook->numel());
int* rulebook_ptr = rulebook->data<int>();
......@@ -312,7 +239,7 @@ int ProductRuleBook(const Context& dev_ctx,
int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
rulebook_ptr,
rulebook_ptr + 2 * kernel_size * non_zero_num,
rulebook_ptr + rulebook_rows * rulebook_cols,
-1);
#ifdef PADDLE_WITH_HIP
......@@ -350,6 +277,7 @@ int ProductRuleBook(const Context& dev_ctx,
dev_ctx.Wait();
int rulebook_len =
(*h_counter)[kernel_size - 1] + (*h_offsets)[kernel_size - 1];
rulebook->Resize({rulebook_rows, rulebook_len});
// 3. sorted or merge the out index
out_index->ResizeAndAllocate({rulebook_len});
......@@ -365,66 +293,30 @@ int ProductRuleBook(const Context& dev_ctx,
unique_key, unique_key->dtype(), sizeof(int) * unique_key->numel());
int* unique_key_ptr = unique_key->data<int>();
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
InitByIndexKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
rulebook_len, out_index_ptr, unique_value_ptr);
#ifdef PADDLE_WITH_HIP
phi::backends::gpu::GpuMemcpyAsync(unique_key_ptr,
rulebook_ptr + rulebook_len,
rulebook_len * sizeof(int),
hipMemcpyDeviceToDevice,
dev_ctx.stream());
#else
phi::backends::gpu::GpuMemcpyAsync(unique_key_ptr,
rulebook_ptr + rulebook_len,
rulebook_len * sizeof(int),
cudaMemcpyDeviceToDevice,
dev_ctx.stream());
#endif
// Compared with thrust::sort_by_key, thrust::merge_by_key may achieve higher
// performance, but thrust::merge_by_key is limited by data size
#ifdef PADDLE_WITH_HIP
thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
unique_key_ptr,
unique_key_ptr + rulebook_len,
out_index_ptr);
// 4. unique
thrust::pair<int*, int*> new_end =
#ifdef PADDLE_WITH_HIP
thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
unique_key_ptr,
unique_key_ptr + rulebook_len,
unique_value_ptr);
int* new_end = SortedAndUniqueIndex(dev_ctx,
rulebook_ptr + 2 * rulebook_len,
rulebook_len,
out_index,
unique_key,
unique_value);
// thrust::distance doesn't support stream parameters
// const int out_non_zero_num = thrust::distance(unique_key_ptr,
// new_end.first);
DistanceKernel<<<1, 1>>>(unique_key_ptr,
new_end.first,
rulebook_ptr + 2 * kernel_size * non_zero_num - 1);
new_end,
rulebook_ptr + rulebook_rows * rulebook_cols - 1);
int out_non_zero_num = 0;
#ifdef PADDLE_WITH_HIP
phi::backends::gpu::GpuMemcpyAsync(
&out_non_zero_num,
rulebook_ptr + 2 * kernel_size * non_zero_num - 1,
rulebook_ptr + rulebook_rows * rulebook_cols - 1,
sizeof(int),
hipMemcpyDeviceToHost,
dev_ctx.stream());
#else
phi::backends::gpu::GpuMemcpyAsync(
&out_non_zero_num,
rulebook_ptr + 2 * kernel_size * non_zero_num - 1,
rulebook_ptr + rulebook_rows * rulebook_cols - 1,
sizeof(int),
cudaMemcpyDeviceToHost,
dev_ctx.stream());
......@@ -440,8 +332,6 @@ int ProductRuleBook(const Context& dev_ctx,
phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
dev_ctx.Alloc(
&out_indices, out_indices.dtype(), sizeof(int) * out_indices.numel());
int* out_indices_ptr = out_indices.data<int>();
config =
......@@ -456,7 +346,7 @@ int ProductRuleBook(const Context& dev_ctx,
rulebook_len,
d_out_dims,
out_indices_ptr,
rulebook_ptr + rulebook_len);
rulebook_ptr + 2 * rulebook_len);
out->SetMember(out_indices, out_values, out_dims, true);
return rulebook_len;
}
......@@ -499,9 +389,12 @@ void Conv3dKernel(const Context& dev_ctx,
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta));
DenseTensor out_index = phi::Empty<int, Context>(dev_ctx);
DenseTensor unique_key = phi::Empty<int, Context>(dev_ctx);
DenseTensor unique_value = phi::Empty<int, Context>(dev_ctx);
DenseTensor out_index = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor unique_key = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor unique_value = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
int n = ProductRuleBook<T, Context>(dev_ctx,
x,
......@@ -522,6 +415,7 @@ void Conv3dKernel(const Context& dev_ctx,
const int* counter_ptr = counter_per_kernel.data<int>();
const int* offsets_ptr = counter_per_kernel.data<int>();
const int* rulebook_ptr = rulebook->data<int>();
// 2. gather
DenseTensorMeta in_features_meta(
......@@ -532,11 +426,7 @@ void Conv3dKernel(const Context& dev_ctx,
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::DenseTensor out_features =
phi::Empty(dev_ctx, std::move(out_features_meta));
dev_ctx.Alloc(
&in_features, in_features.dtype(), sizeof(T) * in_features.numel());
T* in_features_ptr = in_features.data<T>();
dev_ctx.Alloc(
&out_features, out_features.dtype(), sizeof(T) * out_features.numel());
T* out_features_ptr = out_features.data<T>();
auto config =
......@@ -545,7 +435,7 @@ void Conv3dKernel(const Context& dev_ctx,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
rulebook->data<int>(),
rulebook_ptr + n,
in_features_ptr,
n,
in_channels);
......@@ -553,8 +443,6 @@ void Conv3dKernel(const Context& dev_ctx,
// 3. call gemm for every weight
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
auto* out_values = out->mutable_non_zero_elements();
dev_ctx.Alloc(
out_values, out_values->dtype(), sizeof(T) * out_values->numel());
T* out_values_ptr = out_values->data<T>();
const T* kernel_ptr = kernel.data<T>();
......
......@@ -78,9 +78,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensor indices_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
dev_ctx_cpu.Alloc(&indices_tensor,
indices_tensor.dtype(),
sizeof(int) * indices_tensor.numel());
memcpy(
indices_tensor.data<int>(), indices.data(), indices.size() * sizeof(int));
DenseTensor features_tensor = phi::Empty(
......@@ -88,9 +85,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
{non_zero_num, in_channels},
DataLayout::NHWC));
dev_ctx_cpu.Alloc(&features_tensor,
features_tensor.dtype(),
features_tensor.numel() * sizeof(T));
memcpy(
features_tensor.data<T>(), features.data(), features.size() * sizeof(T));
......@@ -101,12 +95,18 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
kernel_dims,
DataLayout::NHWC));
dev_ctx_cpu.Alloc(
&kernel_tensor, kernel_tensor.dtype(), kernel_tensor.numel() * sizeof(T));
memcpy(kernel_tensor.data<T>(), kernel.data(), kernel.size() * sizeof(T));
auto f_verify = [&](const T* real_data, const std::vector<T>& correct_data) {
for (uint64_t i = 0; i < correct_data.size(); i++) {
float tmp = std::fabs(static_cast<float>(correct_data[i] - real_data[i]));
ASSERT_LT(tmp, diff);
}
};
if (!std::is_same<T, phi::dtype::float16>::value) {
DenseTensor rulebook = phi::Empty<int, phi::CPUContext>(dev_ctx_cpu);
DenseTensor rulebook = phi::Empty(
dev_ctx_cpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
SparseCooTensor out = sparse::Conv3d<T>(dev_ctx_cpu,
x_tensor,
kernel_tensor,
......@@ -127,15 +127,6 @@ void TestConv3dBase(const std::vector<int>& indices,
correct_out_indices.size() * sizeof(int));
ASSERT_EQ(cmp_indices, 0);
auto f_verify = [&](const T* real_data,
const std::vector<T>& correct_data) {
for (uint64_t i = 0; i < correct_data.size(); i++) {
float tmp =
std::fabs(static_cast<float>(correct_data[i] - real_data[i]));
ASSERT_LT(tmp, diff);
}
};
f_verify(out.non_zero_elements().data<T>(), correct_out_features);
if (backward) {
......@@ -170,9 +161,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensor d_indices_tensor = phi::Empty(
dev_ctx_gpu,
DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
dev_ctx_gpu.Alloc(&d_indices_tensor,
d_indices_tensor.dtype(),
sizeof(int) * d_indices_tensor.numel());
phi::Copy(
dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor);
......@@ -181,9 +169,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
{non_zero_num, in_channels},
DataLayout::NHWC));
dev_ctx_gpu.Alloc(&d_features_tensor,
d_features_tensor.dtype(),
sizeof(T) * d_features_tensor.numel());
phi::Copy(
dev_ctx_gpu, features_tensor, phi::GPUPlace(), true, &d_features_tensor);
......@@ -194,13 +179,11 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
kernel_dims,
DataLayout::NHWC));
dev_ctx_gpu.Alloc(&d_kernel_tensor,
d_kernel_tensor.dtype(),
sizeof(T) * d_kernel_tensor.numel());
phi::Copy(
dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor);
DenseTensor d_rulebook = phi::Empty<int, phi::GPUContext>(dev_ctx_gpu);
DenseTensor d_rulebook = phi::Empty(
dev_ctx_gpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
SparseCooTensor d_out = sparse::Conv3d<T>(dev_ctx_gpu,
d_x_tensor,
d_kernel_tensor,
......@@ -219,9 +202,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensor h_indices_tensor = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW));
dev_ctx_cpu.Alloc(&h_indices_tensor,
h_indices_tensor.dtype(),
sizeof(int) * h_indices_tensor.numel());
phi::Copy(dev_ctx_gpu,
d_out.non_zero_indices(),
phi::CPUPlace(),
......@@ -239,18 +219,34 @@ void TestConv3dBase(const std::vector<int>& indices,
{d_out.nnz()},
d_out.layout()));
dev_ctx_cpu.Alloc(&h_features_tensor,
h_features_tensor.dtype(),
sizeof(T) * h_features_tensor.numel());
phi::Copy(dev_ctx_gpu,
d_out.non_zero_elements(),
phi::CPUPlace(),
true,
&h_features_tensor);
for (uint64_t i = 0; i < correct_out_features.size(); i++) {
float tmp = std::fabs(static_cast<float>(correct_out_features[i] -
h_features_tensor.data<T>()[i]));
ASSERT_LT(tmp, diff);
}
f_verify(h_features_tensor.data<T>(), correct_out_features);
if (backward) {
std::vector<DenseTensor> grads = sparse::Conv3dGrad<T>(dev_ctx_gpu,
d_x_tensor,
d_rulebook,
d_kernel_tensor,
d_out,
paddings,
dilations,
strides,
1);
DenseTensor h_features_grad = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(grads[0].dtype(), grads[0].dims(), grads[0].layout()));
phi::Copy(dev_ctx_gpu, grads[0], phi::CPUPlace(), true, &h_features_grad);
f_verify(h_features_grad.data<T>(), features_grad);
DenseTensor h_kernel_grad = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(grads[1].dtype(), grads[1].dims(), grads[1].layout()));
phi::Copy(dev_ctx_gpu, grads[1], phi::CPUPlace(), true, &h_kernel_grad);
f_verify(h_kernel_grad.data<T>(), kernel_grad);
}
#endif
}
......