Unverified commit 9841b308, authored by zhangkaihuo, committed by GitHub

Optimize sparse convolution (#43576)

Parent commit: 22342d51
...@@ -80,14 +80,14 @@ ...@@ -80,14 +80,14 @@
data_type : x data_type : x
backward : cast_grad backward : cast_grad
- api : conv3d - api : conv3d_coo
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(out), Tensor(rulebook) output : Tensor(out), Tensor(rulebook), Tensor(counter)
kernel : kernel :
func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense} func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense}
layout : x layout : x
intermediate : rulebook intermediate: rulebook, counter
backward : conv3d_grad backward : conv3d_coo_grad
- api : coo_to_dense - api : coo_to_dense
args : (Tensor x) args : (Tensor x)
...@@ -352,11 +352,11 @@ ...@@ -352,11 +352,11 @@
- api: maxpool - api: maxpool
args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
output : Tensor(out), Tensor(rulebook) output : Tensor(out), Tensor(rulebook), Tensor(counter)
kernel : kernel :
func : maxpool_coo{sparse_coo -> sparse_coo, dense} func : maxpool_coo{sparse_coo -> sparse_coo, dense, dense}
layout : x layout : x
intermediate : rulebook intermediate : rulebook, counter
backward : maxpool_grad backward : maxpool_grad
- api: mv - api: mv
......
...@@ -81,12 +81,12 @@ ...@@ -81,12 +81,12 @@
cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr} cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
data_type : out_grad data_type : out_grad
- backward_api : conv3d_grad - backward_api : conv3d_coo_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(x_grad), Tensor(kernel_grad) output : Tensor(x_grad), Tensor(kernel_grad)
kernel : kernel :
func : conv3d_coo_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense} func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
- backward_api : coo_to_dense_grad - backward_api : coo_to_dense_grad
forward : coo_to_dense(Tensor x) -> Tensor(out) forward : coo_to_dense(Tensor x) -> Tensor(out)
...@@ -164,11 +164,11 @@ ...@@ -164,11 +164,11 @@
matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo} matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}
- backward_api : maxpool_grad - backward_api : maxpool_grad
forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook) forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes) args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes)
output : Tensor(x_grad) output : Tensor(x_grad)
kernel : kernel :
func : maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo} func : maxpool_coo_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> sparse_coo}
- backward_api : multiply_grad - backward_api : multiply_grad
forward : multiply(Tensor x, Tensor y) -> Tensor(out) forward : multiply(Tensor x, Tensor y) -> Tensor(out)
......
...@@ -156,6 +156,48 @@ class SparseCooTensor : public TensorBase, ...@@ -156,6 +156,48 @@ class SparseCooTensor : public TensorBase,
/// \brief get the dense dim /// \brief get the dense dim
int32_t dense_dim() const; int32_t dense_dim() const;
/// \brief query table according to key
const std::pair<DenseTensor, DenseTensor>* IndicesPairs(
const std::string& key) const {
if (indices_dict_ == nullptr) {
return nullptr;
}
const auto& iter = indices_dict_->find(key);
if (iter == indices_dict_->end()) {
return nullptr;
}
return &iter->second;
}
/// \brief save (key, indices_pairs)
void SaveIndicesPairs(
const std::string& key,
const std::pair<DenseTensor, DenseTensor>& indices_pairs) {
if (indices_dict_ == nullptr) {
indices_dict_ = std::make_shared<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>();
}
auto ret = indices_dict_->insert({key, indices_pairs});
if (ret.second == false) {
ret.first->second = indices_pairs;
}
}
/// \brief get indices_dict_
const std::shared_ptr<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
GetIndicesDict() const {
return indices_dict_;
}
/// \brief set indices_dict_
void SetIndicesDict(
const std::shared_ptr<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
indices_dict) {
indices_dict_ = indices_dict;
}
private: private:
// save the indices of non zero elements in original dense tensor // save the indices of non zero elements in original dense tensor
DenseTensor non_zero_indices_; DenseTensor non_zero_indices_;
...@@ -165,6 +207,14 @@ class SparseCooTensor : public TensorBase, ...@@ -165,6 +207,14 @@ class SparseCooTensor : public TensorBase,
bool coalesced_ = false; bool coalesced_ = false;
// save the number of non zero elements in each batch // save the number of non zero elements in each batch
DDim dims_; DDim dims_;
// for submanifold conv
// SubmConv generates a rulebook and a counter, which can be
// reused by subsequent SubmConv ops that share the same key.
// refer to sparse/gpu/convolution_kernel.cu.
std::shared_ptr<std::map<std::string, std::pair<DenseTensor, DenseTensor>>>
indices_dict_ = nullptr;
/* --------------------------- */ /* --------------------------- */
/* example: non zero element is scalar */ /* example: non zero element is scalar */
/* --------------------------- */ /* --------------------------- */
......
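The new IndicesPairs / SaveIndicesPairs / SetIndicesDict members above implement a small string-keyed cache of (rulebook, counter) pairs that SubmConv outputs share through indices_dict_. The following minimal C++ sketch reproduces only that caching pattern outside of phi; FakeTensor, IndicesCache, LookUp and Save are illustrative stand-ins, not part of the Paddle API.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Stand-in for DenseTensor; the real cache stores (rulebook, counter) tensors.
using FakeTensor = std::vector<int>;
using PairMap = std::map<std::string, std::pair<FakeTensor, FakeTensor>>;

struct IndicesCache {
  // Shared between tensors, mirroring SparseCooTensor::indices_dict_.
  std::shared_ptr<PairMap> dict;

  // Mirrors IndicesPairs(): return nullptr when the key is absent.
  const std::pair<FakeTensor, FakeTensor>* LookUp(const std::string& key) const {
    if (!dict) return nullptr;
    auto it = dict->find(key);
    return it == dict->end() ? nullptr : &it->second;
  }

  // Mirrors SaveIndicesPairs(): insert, or overwrite an existing entry.
  void Save(const std::string& key,
            std::pair<FakeTensor, FakeTensor> indices_pairs) {
    if (!dict) dict = std::make_shared<PairMap>();
    auto ret = dict->insert({key, indices_pairs});
    if (!ret.second) ret.first->second = std::move(indices_pairs);
  }
};

int main() {
  IndicesCache x;                          // input of the first SubmConv
  x.Save("subm_conv1", {{0, 1, 2}, {3}});  // rulebook + counter produced once

  IndicesCache out;                        // output tensor shares the same dict,
  out.dict = x.dict;                       // mirroring SetIndicesDict(GetIndicesDict())

  // A later SubmConv with the same key reuses the cached pair.
  const auto* cached = out.LookUp("subm_conv1");
  std::cout << (cached ? "hit" : "miss") << ", rulebook size "
            << (cached ? cached->first.size() : 0) << "\n";
  return 0;
}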
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi { namespace phi {
...@@ -188,6 +189,88 @@ inline void PrefixSum(const T* counter, T* offsets, const int n) { ...@@ -188,6 +189,88 @@ inline void PrefixSum(const T* counter, T* offsets, const int n) {
offsets[n] = offset; offsets[n] = offset;
} }
template <typename IntT>
inline const IntT* GetRulebookPtr(const SparseCooTensor& coo,
const DenseTensor& rulebook,
const std::string& key,
int* rulebook_len) {
if (!key.empty()) {
const auto* indices_pairs = coo.IndicesPairs(key);
if (indices_pairs != nullptr) {
const DenseTensor& tmp_rulebook = indices_pairs->first;
*rulebook_len = tmp_rulebook.dims()[1];
return tmp_rulebook.data<IntT>();
}
}
*rulebook_len = rulebook.dims()[1];
return rulebook.data<IntT>();
}
inline const int* GetCounterPtr(const SparseCooTensor& coo,
const DenseTensor& counter,
const std::string& key) {
if (!key.empty()) {
const auto* indices_pairs = coo.IndicesPairs(key);
if (indices_pairs != nullptr) {
return indices_pairs->second.data<int>();
}
}
return counter.data<int>();
}
template <typename T, typename IntT, typename Context>
inline const IntT* PrepareSubm(const Context& dev_ctx,
const SparseCooTensor& x,
const std::string& key,
const DDim& out_dims,
SparseCooTensor* out,
int* counter,
int* offsets,
int* rulebook_len,
bool* need_product_rulebook) {
const auto* indices_pairs = x.IndicesPairs(key);
if (indices_pairs != nullptr) {
*need_product_rulebook = false;
const DenseTensor& rulebook = indices_pairs->first;
const int counter_size = indices_pairs->second.numel();
memcpy(
counter, indices_pairs->second.data<int>(), counter_size * sizeof(int));
out->SetIndicesDict(x.GetIndicesDict());
*rulebook_len = rulebook.dims()[1];
DenseTensor out_indices =
phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
phi::Copy(
dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices);
out->SetMember(out_indices, out_values, out_dims, false);
PrefixSum<int>(counter, offsets, counter_size);
return rulebook.data<IntT>();
}
return nullptr;
}
template <typename Context>
inline void SaveToTable(const Context& dev_ctx,
const SparseCooTensor& x,
const std::string& key,
const DenseTensor& in_rulebook,
const DenseTensor& h_counter,
SparseCooTensor* out,
DenseTensor* out_rulebook,
DenseTensor* counter) {
out->SetIndicesDict(x.GetIndicesDict());
if (!key.empty()) {
out->SaveIndicesPairs(key, std::make_pair(in_rulebook, h_counter));
} else {
*out_rulebook = in_rulebook;
counter->Resize({h_counter.numel()});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, h_counter.data<int>(), h_counter.numel() * sizeof(int));
}
}
} // namespace sparse } // namespace sparse
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
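PrepareSubm copies the cached counter to host memory and turns it into offsets with PrefixSum; the conv kernels later in this change then slice the gathered feature matrix with offsets[i] and counter[i] to run one GEMM per kernel position. Below is a self-contained CPU sketch of that bookkeeping, assuming the exclusive prefix-sum form whose last slot holds the total, which is how the kernels compute offsets; plain arrays replace DenseTensor and all names are illustrative.

#include <iostream>
#include <vector>

// Exclusive prefix sum over the per-kernel-position counter:
// offsets[i] = sum of counter[0..i-1], offsets[n] = total rulebook length.
void PrefixSumRef(const int* counter, int* offsets, int n) {
  int offset = 0;
  for (int i = 0; i < n; ++i) {
    offsets[i] = offset;
    offset += counter[i];
  }
  offsets[n] = offset;
}

int main() {
  // counter[i]: how many rulebook entries fall on kernel position i.
  std::vector<int> counter = {2, 0, 3, 1};
  const int kernel_size = static_cast<int>(counter.size());
  std::vector<int> offsets(kernel_size + 1);
  PrefixSumRef(counter.data(), offsets.data(), kernel_size);

  // Each non-empty kernel position i becomes one GEMM of M = counter[i]
  // rows, starting at row offsets[i] of the gathered feature matrix.
  for (int i = 0; i < kernel_size; ++i) {
    if (counter[i] <= 0) continue;  // skipped, as in the conv kernels
    std::cout << "kernel position " << i << ": rows [" << offsets[i] << ", "
              << offsets[i] + counter[i] << ")\n";
  }
  std::cout << "rulebook_len = " << offsets[kernel_size] << "\n";
  return 0;
}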
...@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#define VecBytes 16
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -28,33 +33,126 @@ namespace sparse { ...@@ -28,33 +33,126 @@ namespace sparse {
* channels: the output channel size * channels: the output channel size
* out: the outputs * out: the outputs
**/ **/
template <typename T> template <typename T, int VecSize>
__global__ void ScatterKernel(const T* input, __global__ void ScatterKernel(const T* input,
const int* unique_value, const int* unique_value,
const int* out_index, const int* out_index,
const int non_zero_num, const int non_zero_num,
const int rulebook_len, const int rulebook_len,
const int channels, const int channels,
T* out, T* out) {
const bool subm = false) {
int tid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) { const int vec_channels = channels / VecSize;
int indices_i = i / channels; using LoadT = phi::AlignedVector<T, VecSize>;
int channels_i = i - indices_i * channels; using StoreT = phi::AlignedVector<T, VecSize>;
for (int i = tid; i < non_zero_num * vec_channels;
i += gridDim.x * blockDim.x) {
int indices_i = i / vec_channels;
int channels_i = i - indices_i * vec_channels;
int start = unique_value[indices_i]; int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1]; : unique_value[indices_i + 1];
// max(end-start) = kernel_size // max(end-start) = kernel_size
T sum = static_cast<T>(0); StoreT sums = {static_cast<T>(0)};
if (subm) {
sum = out[indices_i * channels + channels_i];
}
for (int j = start; j < end; j++) { for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j]; const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i]; LoadT vec_in;
phi::Load<T, VecSize>(
input + out_feature_i * channels + channels_i * VecSize, &vec_in);
#pragma unroll
for (int k = 0; k < VecSize; k++) {
sums[k] += vec_in[k];
}
} }
out[indices_i * channels + channels_i] = sum; phi::Store<T, VecSize>(sums,
out + indices_i * channels + channels_i * VecSize);
}
}
// The scatter indices have been grouped in advance:
// index_counts records the number of entries in each group,
// index_groups stores the indices belonging to each group.
template <typename T, int VecSize>
__global__ void ScatterKernelV2(const T* input,
const int* index_counts,
const int* index_groups,
const int non_zero_num,
const int kernel_size,
const int channels,
const int buffer_counts,
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int vec_channels = channels / VecSize;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
for (int i = tid; i < non_zero_num * vec_channels;
i += gridDim.x * blockDim.x) {
int indices_i = i / vec_channels;
int channels_i = i - indices_i * vec_channels;
StoreT sums = {static_cast<T>(0)};
phi::Load<T, VecSize>(out + indices_i * channels + channels_i * VecSize,
&sums);
for (int it = 0; it < buffer_counts; it++) {
int len = index_counts[indices_i + it * non_zero_num];
const int group_offset = it * kernel_size * non_zero_num;
for (int j = 0; j < len; j++) {
const int out_feature_i =
index_groups[indices_i * kernel_size + j + group_offset];
LoadT vec_in;
phi::Load<T, VecSize>(
input + out_feature_i * channels + channels_i * VecSize, &vec_in);
#pragma unroll
for (int k = 0; k < VecSize; k++) {
sums[k] += vec_in[k];
}
}
}
phi::Store<T, VecSize>(sums,
out + indices_i * channels + channels_i * VecSize);
}
}
template <typename T>
void ScatterV2(const GPUContext& dev_ctx,
const T* input,
const int* index_counts,
const int* index_groups,
const int non_zero_num,
const int kernel_size,
const int channels,
const int buffer_counts,
T* output) {
const int VecSize = VecBytes / sizeof(T);
if (channels % VecSize == 0) {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, non_zero_num * channels / VecSize, 1);
ScatterKernelV2<T, VecSize><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(input,
index_counts,
index_groups,
non_zero_num,
kernel_size,
channels,
buffer_counts,
output);
} else {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, non_zero_num * channels, 1);
ScatterKernelV2<T, 1><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(input,
index_counts,
index_groups,
non_zero_num,
kernel_size,
channels,
buffer_counts,
output);
} }
} }
......
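ScatterKernelV2 assumes the scatter indices were grouped beforehand: index_counts holds one group length per output row per buffer (buffers spaced non_zero_num apart), index_groups holds the member indices (buffers spaced kernel_size * non_zero_num apart), and up to buffer_counts buffers are accumulated per row. Below is a simplified single-threaded C++ reference of the same accumulation without the AlignedVector vectorization; the function name and the tiny test data are illustrative only.

#include <cstdio>
#include <vector>

// Single-threaded reference of the grouped scatter done by ScatterKernelV2:
// for each output row, sum the input rows listed in its pre-grouped indices.
void ScatterGroupedRef(const float* input, const int* index_counts,
                       const int* index_groups, int non_zero_num,
                       int kernel_size, int channels, int buffer_counts,
                       float* out /* accumulated in place */) {
  for (int row = 0; row < non_zero_num; ++row) {
    for (int it = 0; it < buffer_counts; ++it) {
      const int len = index_counts[row + it * non_zero_num];
      const int group_offset = it * kernel_size * non_zero_num;
      for (int j = 0; j < len; ++j) {
        const int in_row = index_groups[row * kernel_size + j + group_offset];
        for (int c = 0; c < channels; ++c) {
          out[row * channels + c] += input[in_row * channels + c];
        }
      }
    }
  }
}

int main() {
  // 2 output rows, kernel_size = 3, channels = 2, one group buffer.
  const int non_zero_num = 2, kernel_size = 3, channels = 2;
  std::vector<float> input = {1, 1, 2, 2, 3, 3};  // 3 gathered feature rows
  std::vector<int> index_counts = {2, 1};         // row 0 sums 2 rows, row 1 sums 1
  std::vector<int> index_groups = {0, 2, 0,       // row 0 <- input rows 0 and 2
                                   1, 0, 0};      // row 1 <- input row 1
  std::vector<float> out(non_zero_num * channels, 0.f);
  ScatterGroupedRef(input.data(), index_counts.data(), index_groups.data(),
                    non_zero_num, kernel_size, channels, /*buffer_counts=*/1,
                    out.data());
  for (float v : out) std::printf("%g ", v);  // expected: 4 4 2 2
  std::printf("\n");
  return 0;
}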
...@@ -25,13 +25,16 @@ template <typename T, typename Context> ...@@ -25,13 +25,16 @@ template <typename T, typename Context>
void Conv3dCooGradKernel(const Context& dev_ctx, void Conv3dCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& kernel, const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* x_grad, SparseCooTensor* x_grad,
DenseTensor* kernel_grad); DenseTensor* kernel_grad);
...@@ -40,13 +43,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad( ...@@ -40,13 +43,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad(
const Context& dev_ctx, const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& kernel, const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm) { const bool subm,
const std::string& key) {
SparseCooTensor x_grad; SparseCooTensor x_grad;
DenseTensor kernel_grad; DenseTensor kernel_grad;
...@@ -54,13 +60,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad( ...@@ -54,13 +60,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad(
Conv3dCooGradKernel<T, Context>(dev_ctx, Conv3dCooGradKernel<T, Context>(dev_ctx,
x, x,
kernel, kernel,
out,
rulebook, rulebook,
counter,
out_grad, out_grad,
paddings, paddings,
dilations, dilations,
strides, strides,
groups, groups,
subm, subm,
key,
&x_grad, &x_grad,
&kernel_grad); &kernel_grad);
return std::make_tuple(x_grad, kernel_grad); return std::make_tuple(x_grad, kernel_grad);
......
...@@ -31,8 +31,10 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -31,8 +31,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook); DenseTensor* rulebook,
DenseTensor* counter);
template <typename T, typename Context> template <typename T, typename Context>
SparseCooTensor Conv3dCoo(const Context& dev_ctx, SparseCooTensor Conv3dCoo(const Context& dev_ctx,
...@@ -43,7 +45,9 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx, ...@@ -43,7 +45,9 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
DenseTensor* rulebook) { const std::string& key,
DenseTensor* rulebook,
DenseTensor* counter) {
SparseCooTensor coo; SparseCooTensor coo;
Conv3dCooKernel<T, Context>(dev_ctx, Conv3dCooKernel<T, Context>(dev_ctx,
x, x,
...@@ -53,8 +57,10 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx, ...@@ -53,8 +57,10 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
strides, strides,
groups, groups,
subm, subm,
key,
&coo, &coo,
rulebook); rulebook,
counter);
return coo; return coo;
} }
......
...@@ -41,13 +41,12 @@ void ProductRuleBook(const Context& dev_ctx, ...@@ -41,13 +41,12 @@ void ProductRuleBook(const Context& dev_ctx,
const DDim& out_dims, const DDim& out_dims,
const bool subm, const bool subm,
DenseTensor* rulebook, DenseTensor* rulebook,
DenseTensor* counter_per_kernel) { int* counter_per_kernel) {
const int64_t non_zero_num = x.nnz(); const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices(); const auto& non_zero_indices = x.non_zero_indices();
const IntT* indices_ptr = non_zero_indices.data<IntT>(); const IntT* indices_ptr = non_zero_indices.data<IntT>();
int* counter_ptr = counter_per_kernel->data<int>();
int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
memset(counter_ptr, 0, kernel_size * sizeof(int)); memset(counter_per_kernel, 0, kernel_size * sizeof(int));
int rulebook_len = 0; int rulebook_len = 0;
// calc the rulebook_len // calc the rulebook_len
...@@ -107,7 +106,7 @@ void ProductRuleBook(const Context& dev_ctx, ...@@ -107,7 +106,7 @@ void ProductRuleBook(const Context& dev_ctx,
} }
if (rulebook_ptr == nullptr) { if (rulebook_ptr == nullptr) {
counter_ptr[kernel_index - 1] += 1; counter_per_kernel[kernel_index - 1] += 1;
++rulebook_len; ++rulebook_len;
} else { } else {
rulebook_ptr[rulebook_index] = kernel_index - 1; rulebook_ptr[rulebook_index] = kernel_index - 1;
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h" #include "paddle/phi/kernels/sparse/cpu/conv.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
...@@ -34,22 +34,27 @@ template <typename T, typename IntT = int> ...@@ -34,22 +34,27 @@ template <typename T, typename IntT = int>
void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& kernel, const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* x_grad, SparseCooTensor* x_grad,
DenseTensor* kernel_grad) { DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims(); const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
const int in_channels = kernel_dims[3]; const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4]; const int out_channels = kernel_dims[4];
const IntT* rulebook_ptr = rulebook.data<IntT>();
const int rulebook_len = rulebook.dims()[1]; int rulebook_len = 0;
const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
out, rulebook, key, &rulebook_len);
const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key);
DenseTensorMeta in_features_meta( DenseTensorMeta in_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW); x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
...@@ -86,16 +91,14 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, ...@@ -86,16 +91,14 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
&x_grad_indices); &x_grad_indices);
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0); std::vector<IntT> offsets(kernel_size + 1);
for (int i = 0; i < rulebook_len; i++) { IntT offset = 0;
counter[rulebook_ptr[i]] += 1; int max_count = 0;
}
IntT offset = 0, max_count = 0;
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset; offsets[i] = offset;
offset += counter[i]; offset += counter_ptr[i];
if (i < half_kernel_size) { if (i < half_kernel_size) {
max_count = std::max(max_count, counter[i]); max_count = std::max(max_count, counter_ptr[i]);
} }
} }
offsets[kernel_size] = offset; offsets[kernel_size] = offset;
...@@ -129,11 +132,11 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, ...@@ -129,11 +132,11 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
const T* kernel_ptr = kernel.data<T>(); const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0 || (subm && i == half_kernel_size)) { if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
continue; continue;
} }
const int M = counter[i]; const int M = counter_ptr[i];
const int K = in_channels; const int K = in_channels;
const int N = out_channels; const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
...@@ -171,7 +174,7 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, ...@@ -171,7 +174,7 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
// 4. scatter // 4. scatter
Scatter<T, IntT>(d_x_features_ptr, Scatter<T, IntT>(d_x_features_ptr,
rulebook.data<IntT>() + rulebook_len, rulebook_ptr + rulebook_len,
rulebook_len, rulebook_len,
in_channels, in_channels,
x_grad_values_ptr); x_grad_values_ptr);
...@@ -181,13 +184,16 @@ template <typename T, typename Context> ...@@ -181,13 +184,16 @@ template <typename T, typename Context>
void Conv3dCooGradKernel(const Context& dev_ctx, void Conv3dCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& kernel, const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* x_grad, SparseCooTensor* x_grad,
DenseTensor* kernel_grad) { DenseTensor* kernel_grad) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
...@@ -195,13 +201,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx, ...@@ -195,13 +201,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
Conv3dCooGradCPUKernel<T, data_t>(dev_ctx, Conv3dCooGradCPUKernel<T, data_t>(dev_ctx,
x, x,
kernel, kernel,
out,
rulebook, rulebook,
counter,
out_grad, out_grad,
paddings, paddings,
dilations, dilations,
strides, strides,
groups, groups,
subm, subm,
key,
x_grad, x_grad,
kernel_grad); kernel_grad);
})); }));
......
...@@ -14,9 +14,10 @@ limitations under the License. */ ...@@ -14,9 +14,10 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h" #include "paddle/phi/kernels/sparse/cpu/conv.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
...@@ -35,8 +36,10 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, ...@@ -35,8 +36,10 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
// update padding and dilation // update padding and dilation
// Currently, only support x.layout is NDHWC, groups = 1 // Currently, only support x.layout is NDHWC, groups = 1
// if x.layout != NDHWC then transpose(x), transpose(weight) // if x.layout != NDHWC then transpose(x), transpose(weight)
...@@ -66,26 +69,50 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, ...@@ -66,26 +69,50 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
// Second algorithm: // Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook // 1. product rulebook
DenseTensorMeta counter_meta( DenseTensor h_counter, h_offsets;
DataType::INT32, {kernel_size}, DataLayout::NCHW); h_counter.Resize({kernel_size});
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); h_offsets.Resize({kernel_size + 1});
int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
ProductRuleBook<T, CPUContext, IntT>(dev_ctx, int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
x,
kernel_sizes, // DenseTensor* rulebook = nullptr;
subm_paddings, const IntT* rulebook_ptr = nullptr;
dilations, int n = 0;
subm_strides, bool need_product_rulebook = true;
out_dims, if (subm && !key.empty()) {
subm, rulebook_ptr = phi::funcs::sparse::PrepareSubm<T, IntT, CPUContext>(
rulebook, dev_ctx,
&counter_per_kernel); x,
key,
UpdateRulebookAndOutIndex<T, CPUContext, IntT>( out_dims,
dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out); out,
h_counter_ptr,
int n = rulebook->dims()[1]; h_offsets_ptr,
const int* counter_ptr = counter_per_kernel.data<int>(); &n,
&need_product_rulebook);
}
if (need_product_rulebook) {
DenseTensor tmp_rulebook;
ProductRuleBook<T, CPUContext, IntT>(dev_ctx,
x,
kernel_sizes,
subm_paddings,
dilations,
subm_strides,
out_dims,
subm,
&tmp_rulebook,
h_counter_ptr);
UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
dev_ctx, x, kernel_size, out_channels, out_dims, &tmp_rulebook, out);
n = tmp_rulebook.dims()[1];
rulebook_ptr = tmp_rulebook.data<IntT>();
phi::funcs::sparse::SaveToTable(
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
}
// int n = rulebook->dims()[1];
// 2. gather // 2. gather
DenseTensorMeta in_features_meta( DenseTensorMeta in_features_meta(
...@@ -100,34 +127,33 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, ...@@ -100,34 +127,33 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
T* out_features_ptr = out_features.data<T>(); T* out_features_ptr = out_features.data<T>();
Gather<T, IntT>(x.non_zero_elements().data<T>(), Gather<T, IntT>(x.non_zero_elements().data<T>(),
rulebook->data<IntT>() + n, rulebook_ptr + n,
n, n,
in_channels, in_channels,
in_features_ptr); in_features_ptr);
// 3. call gemm for every weight // 3. call gemm for every weight
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx); auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
std::vector<int> offsets(kernel_size + 1);
int offset = 0; int offset = 0;
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset; h_offsets_ptr[i] = offset;
offset += counter_ptr[i]; offset += h_counter_ptr[i];
} }
offsets[kernel_size] = offset; h_offsets_ptr[kernel_size] = offset;
const T* kernel_ptr = kernel.data<T>(); const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
if (counter_ptr[i] <= 0) { if (h_counter_ptr[i] <= 0) {
continue; continue;
} }
// call gemm: (n, in_channels) * (in_channels, out_channels) // call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = counter_ptr[i]; const int M = h_counter_ptr[i];
const int K = in_channels; // in_channels const int K = in_channels; // in_channels
const int N = out_channels; // out_channels const int N = out_channels; // out_channels
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N; const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels; T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans, blas.GEMM(CblasNoTrans,
CblasNoTrans, CblasNoTrans,
M, M,
...@@ -143,11 +169,8 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, ...@@ -143,11 +169,8 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
// 4. scatter // 4. scatter
T* out_values_ptr = out->mutable_non_zero_elements()->data<T>(); T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels); memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
Scatter<T, IntT>(out_features_ptr, Scatter<T, IntT>(
rulebook->data<IntT>() + n * 2, out_features_ptr, rulebook_ptr + n * 2, n, out_channels, out_values_ptr);
n,
out_channels,
out_values_ptr);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -159,8 +182,10 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -159,8 +182,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] { x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] {
Conv3dCooCPUKernel<T, data_t>(dev_ctx, Conv3dCooCPUKernel<T, data_t>(dev_ctx,
...@@ -171,8 +196,10 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -171,8 +196,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
strides, strides,
groups, groups,
subm, subm,
key,
out, out,
rulebook); rulebook,
counter);
})); }));
} }
......
...@@ -28,6 +28,7 @@ template <typename T, typename IntT = int> ...@@ -28,6 +28,7 @@ template <typename T, typename IntT = int>
void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx, void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out, const SparseCooTensor& out,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
...@@ -36,11 +37,10 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx, ...@@ -36,11 +37,10 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
const int channels = x.dims()[4]; const int channels = x.dims()[4];
int rulebook_len = rulebook.dims()[1]; int rulebook_len = rulebook.dims()[1];
const IntT* rulebook_ptr = rulebook.data<IntT>(); const IntT* rulebook_ptr = rulebook.data<IntT>();
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0); std::vector<int> offsets(kernel_size + 1);
for (int i = 0; i < rulebook_len; i++) { const int* counter_ptr = counter.data<int>();
counter[rulebook_ptr[i]] += 1;
} phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);
const T* in_features_ptr = x.non_zero_elements().data<T>(); const T* in_features_ptr = x.non_zero_elements().data<T>();
const T* out_features_ptr = out.non_zero_elements().data<T>(); const T* out_features_ptr = out.non_zero_elements().data<T>();
...@@ -60,7 +60,7 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx, ...@@ -60,7 +60,7 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
phi::funcs::MaxPoolGrad<T> grad_functor; phi::funcs::MaxPoolGrad<T> grad_functor;
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
for (int j = 0; j < counter[i]; j++) { for (int j = 0; j < counter_ptr[i]; j++) {
IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j];
IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j];
for (int c = 0; c < channels; c++) { for (int c = 0; c < channels; c++) {
...@@ -78,6 +78,7 @@ template <typename T, typename Context> ...@@ -78,6 +78,7 @@ template <typename T, typename Context>
void MaxPoolCooGradKernel(const Context& dev_ctx, void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out, const SparseCooTensor& out,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
...@@ -85,7 +86,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx, ...@@ -85,7 +86,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] { x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] {
MaxPoolCooGradCPUKernel<T, data_t>( MaxPoolCooGradCPUKernel<T, data_t>(
dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
})); }));
} }
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h" #include "paddle/phi/kernels/sparse/cpu/conv.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
...@@ -37,7 +37,8 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx, ...@@ -37,7 +37,8 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
const std::vector<int>& real_kernel_sizes = const std::vector<int>& real_kernel_sizes =
...@@ -47,9 +48,7 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx, ...@@ -47,9 +48,7 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims); x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
const int in_channels = real_kernel_sizes[3]; const int in_channels = real_kernel_sizes[3];
DenseTensorMeta counter_meta( std::vector<int> counter_per_kernel(kernel_size, 0);
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
const T* in_features_ptr = x.non_zero_elements().data<T>(); const T* in_features_ptr = x.non_zero_elements().data<T>();
// 1. product rule book // 1. product rule book
...@@ -62,14 +61,17 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx, ...@@ -62,14 +61,17 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
out_dims, out_dims,
false, false,
rulebook, rulebook,
&counter_per_kernel); counter_per_kernel.data());
UpdateRulebookAndOutIndex<T, CPUContext, IntT>( UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out); dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out);
int rulebook_len = rulebook->dims()[1]; int rulebook_len = rulebook->dims()[1];
const IntT* rulebook_ptr = rulebook->data<IntT>(); const IntT* rulebook_ptr = rulebook->data<IntT>();
const int* counter_ptr = counter_per_kernel.data<int>();
counter->Resize({kernel_size});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, counter_per_kernel.data(), kernel_size * sizeof(int));
std::vector<int> offsets(kernel_size + 1); std::vector<int> offsets(kernel_size + 1);
phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size); phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
...@@ -105,7 +107,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, ...@@ -105,7 +107,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] { x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] {
MaxPoolCooCPUKernel<T, data_t>(dev_ctx, MaxPoolCooCPUKernel<T, data_t>(dev_ctx,
...@@ -115,7 +118,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, ...@@ -115,7 +118,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
dilations, dilations,
strides, strides,
out, out,
rulebook); rulebook,
counter);
})); }));
} }
......
...@@ -125,16 +125,35 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx, ...@@ -125,16 +125,35 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx,
} }
// 5. scatter the values // 5. scatter the values
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1); const int VecSize = VecBytes / sizeof(T);
phi::funcs::sparse::ScatterKernel<T> if (stride % VecSize == 0) {
<<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>( config = phi::backends::gpu::GetGpuLaunchConfig1D(
x_values_ptr, dev_ctx, nnz * stride / VecSize, 1);
public_indexs.data<int>(), phi::funcs::sparse::ScatterKernel<T, VecSize>
values_indexs_ptr, <<<config.block_per_grid,
out_nnz, config.thread_per_block,
nnz, 0,
stride, dev_ctx.stream()>>>(x_values_ptr,
out_values.data<T>()); public_indexs.data<int>(),
values_indexs_ptr,
out_nnz,
nnz,
stride,
out_values.data<T>());
} else {
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
phi::funcs::sparse::ScatterKernel<T, 1>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x_values_ptr,
public_indexs.data<int>(),
values_indexs_ptr,
out_nnz,
nnz,
stride,
out_values.data<T>());
}
// 6. convert index to coordinate // 6. convert index to coordinate
Dim<DDim::kMaxRank> const_dims; Dim<DDim::kMaxRank> const_dims;
......
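Both the coalesce scatter above and ScatterV2 derive their vector width from VecBytes (16 bytes) and fall back to the scalar VecSize = 1 instantiation when the innermost dimension is not divisible by it. A small host-side C++ sketch of that dispatch rule follows (illustrative only, no kernel launch).

#include <cstdio>

// Mirrors the "VecBytes = 16" rule: vectorize loads/stores when the
// innermost dimension is divisible by 16 / sizeof(T), otherwise fall back
// to the scalar (VecSize = 1) kernel instantiation.
template <typename T>
int ChooseVecSize(int innermost_dim) {
  constexpr int kVecBytes = 16;
  const int vec_size = kVecBytes / static_cast<int>(sizeof(T));
  return (innermost_dim % vec_size == 0) ? vec_size : 1;
}

int main() {
  std::printf("float,  channels=64 -> VecSize %d\n", ChooseVecSize<float>(64));   // 4
  std::printf("float,  channels=3  -> VecSize %d\n", ChooseVecSize<float>(3));    // 1
  std::printf("double, channels=64 -> VecSize %d\n", ChooseVecSize<double>(64));  // 2
  return 0;
}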
This diff is collapsed.
...@@ -19,13 +19,11 @@ limitations under the License. */ ...@@ -19,13 +19,11 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
...@@ -42,43 +40,42 @@ template <typename T, typename IntT> ...@@ -42,43 +40,42 @@ template <typename T, typename IntT>
void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& kernel, const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* x_grad, SparseCooTensor* x_grad,
DenseTensor* kernel_grad) { DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims(); const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
const int in_channels = kernel_dims[3]; const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4]; const int out_channels = kernel_dims[4];
const IntT* rulebook_ptr = rulebook.data<IntT>();
const int rulebook_len = rulebook.dims()[1]; int rulebook_len = 0;
const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
out, rulebook, key, &rulebook_len);
const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key);
DenseTensorMeta in_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
DenseTensorMeta d_x_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
DenseTensorMeta out_grad_features_meta(
x.dtype(), {rulebook_len, out_channels}, DataLayout::NCHW);
phi::DenseTensor in_features = phi::DenseTensor in_features =
phi::Empty(dev_ctx, std::move(in_features_meta)); phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor d_x_features = phi::DenseTensor d_x_features =
phi::Empty(dev_ctx, std::move(d_x_features_meta)); phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_grad_features = phi::DenseTensor out_grad_features =
phi::Empty(dev_ctx, std::move(out_grad_features_meta)); phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>(); T* in_features_ptr = in_features.data<T>();
T* d_x_features_ptr = d_x_features.data<T>(); T* d_x_features_ptr = d_x_features.data<T>();
T* out_grad_features_ptr = out_grad_features.data<T>(); T* out_grad_features_ptr = out_grad_features.data<T>();
*kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel); *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
T* d_kernel_ptr = kernel_grad->data<T>(); T* d_kernel_ptr = kernel_grad->data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero; phi::backends::gpu::GpuMemsetAsync(
set_zero(dev_ctx, kernel_grad, static_cast<T>(0.0f)); d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
int half_kernel_size = kernel_size / 2; int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx); auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
...@@ -86,8 +83,12 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -86,8 +83,12 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices()); phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements()); DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
T* x_grad_values_ptr = x_grad_values.data<T>(); T* x_grad_values_ptr = x_grad_values.data<T>();
set_zero(dev_ctx, &x_grad_values, static_cast<T>(0.0f)); phi::backends::gpu::GpuMemsetAsync(x_grad_values_ptr,
set_zero(dev_ctx, &d_x_features, static_cast<T>(0.0f)); 0,
sizeof(T) * x_grad_values.numel(),
dev_ctx.stream());
phi::backends::gpu::GpuMemsetAsync(
d_x_features_ptr, 0, sizeof(T) * d_x_features.numel(), dev_ctx.stream());
phi::Copy<GPUContext>(dev_ctx, phi::Copy<GPUContext>(dev_ctx,
x.non_zero_indices(), x.non_zero_indices(),
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
...@@ -95,29 +96,14 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -95,29 +96,14 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
&x_grad_indices); &x_grad_indices);
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0), std::vector<int> offsets(kernel_size + 1);
h_counter(rulebook_len, 0);
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
rulebook_ptr,
rulebook_len * sizeof(IntT),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
for (int i = 0; i < rulebook_len; i++) { int offset = 0, max_count = 0;
counter[h_counter[i]] += 1;
}
IntT offset = 0, max_count = 0;
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset; offsets[i] = offset;
offset += counter[i]; offset += counter_ptr[i];
if (i < half_kernel_size) { if (i < half_kernel_size) {
max_count = std::max(max_count, counter[i]); max_count = std::max(max_count, counter_ptr[i]);
} }
} }
offsets[kernel_size] = offset; offsets[kernel_size] = offset;
...@@ -138,36 +124,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -138,36 +124,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
} }
} }
auto config = phi::backends::gpu::GetGpuLaunchConfig1D( auto config =
dev_ctx, rulebook_len * in_channels, 1); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
GatherKernel<T, IntT><<<config.block_per_grid.x, DenseTensor unique_value = phi::Empty<int>(
config.thread_per_block.x, dev_ctx, {static_cast<int>(x_grad->nnz() * kernel_size * 2)});
0, DenseTensor out_index =
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(), phi::Empty<int>(dev_ctx, {static_cast<int>(x.nnz() * 2)});
rulebook_ptr + rulebook_len, int* out_index_ptr = out_index.data<int>();
in_features_ptr, int* unique_value_ptr = unique_value.data<int>();
rulebook_len, phi::backends::gpu::GpuMemsetAsync(
in_channels); out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());
config = phi::backends::gpu::GetGpuLaunchConfig1D( GroupIndexsV2<<<config.block_per_grid,
dev_ctx, rulebook_len * out_channels, 1); config.thread_per_block,
GatherKernel<T, IntT> 0,
<<<config.block_per_grid.x, dev_ctx.stream()>>>(rulebook_len,
config.thread_per_block.x, x.nnz(),
0, kernel_size,
dev_ctx.stream()>>>(out_grad.non_zero_elements().data<T>(), offsets[kernel_size / 2],
rulebook_ptr + rulebook_len * 2, rulebook_ptr,
out_grad_features_ptr, out_index_ptr,
rulebook_len, unique_value_ptr);
out_channels);
GatherV2<T, IntT>(dev_ctx,
x.non_zero_elements().data<T>(),
out_index_ptr,
unique_value_ptr,
x.nnz(),
kernel_size,
in_channels,
2,
in_features_ptr);
Gather<T, IntT>(dev_ctx,
out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
rulebook_len,
out_channels,
out_grad_features_ptr);
const T* kernel_ptr = kernel.data<T>(); const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0 || (subm && i == half_kernel_size)) { if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
continue; continue;
} }
const int M = counter[i]; const int M = counter_ptr[i];
const int K = in_channels; const int K = in_channels;
const int N = out_channels; const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
...@@ -204,32 +206,31 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -204,32 +206,31 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
} }
// 4. scatter // 4. scatter
config = phi::backends::gpu::GetGpuLaunchConfig1D( phi::funcs::sparse::ScatterV2<T>(dev_ctx,
dev_ctx, rulebook_len * in_channels, 1); d_x_features_ptr,
out_index.data<int>(),
phi::funcs::ScatterCUDAKernel<<<config.block_per_grid, unique_value.data<int>(),
config.thread_per_block, x_grad->nnz(),
0, kernel_size,
dev_ctx.stream()>>>( in_channels,
d_x_features_ptr, 2,
rulebook_ptr + rulebook_len, x_grad_values_ptr);
x_grad_values_ptr,
rulebook_len,
in_channels,
false);
} }
template <typename T, typename Context> template <typename T, typename Context>
void Conv3dCooGradKernel(const Context& dev_ctx, void Conv3dCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& kernel, const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* x_grad, SparseCooTensor* x_grad,
DenseTensor* kernel_grad) { DenseTensor* kernel_grad) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
...@@ -237,13 +238,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx, ...@@ -237,13 +238,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
Conv3dCooGradGPUKernel<T, data_t>(dev_ctx, Conv3dCooGradGPUKernel<T, data_t>(dev_ctx,
x, x,
kernel, kernel,
out,
rulebook, rulebook,
counter,
out_grad, out_grad,
paddings, paddings,
dilations, dilations,
strides, strides,
groups, groups,
subm, subm,
key,
x_grad, x_grad,
kernel_grad); kernel_grad);
})); }));
......
...@@ -21,7 +21,9 @@ limitations under the License. */ ...@@ -21,7 +21,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" #include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#include "glog/logging.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
...@@ -35,8 +37,10 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -35,8 +37,10 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
// update padding and dilation // update padding and dilation
// Currently, only support x.layout is NDHWC, groups = 1 // Currently, only support x.layout is NDHWC, groups = 1
// if x.layout != NDHWC then transpose(x), transpose(weight) // if x.layout != NDHWC then transpose(x), transpose(weight)
...@@ -61,85 +65,117 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -61,85 +65,117 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims);
const int in_channels = kernel_dims[3]; const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4]; const int out_channels = kernel_dims[4];
std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size); DenseTensor h_counter, h_offsets;
h_counter.Resize({kernel_size});
h_offsets.Resize({kernel_size + 1});
int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
// Second algorithm: // Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook // 1. product rulebook
DenseTensorMeta counter_meta( DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
DataType::INT32, {kernel_size}, DataLayout::NCHW); DenseTensor offsets_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
DenseTensorMeta offsets_meta( DenseTensor out_index = phi::Empty<int>(dev_ctx, {1});
DataType::INT32, {kernel_size}, DataLayout::NCHW); DenseTensor unique_value = phi::Empty<int>(dev_ctx, {1});
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta)); VLOG(6) << "call SubmConv3D or Conv3D " << subm << " and the key is " << key;
DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); int rulebook_len = 0;
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); const IntT* rulebook_ptr = nullptr;
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); bool need_product_rulebook = true;
if (subm && !key.empty()) {
int n = ProductRuleBook<T, GPUContext, IntT>(dev_ctx, rulebook_ptr = phi::funcs::sparse::PrepareSubm<T, IntT, GPUContext>(
x, dev_ctx,
kernel_sizes, x,
subm_paddings, key,
dilations, out_dims,
subm_strides, out,
out_dims, h_counter.data<int>(),
subm, h_offsets.data<int>(),
rulebook, &rulebook_len,
&counter_per_kernel, &need_product_rulebook);
&offsets_per_kernel, }
&out_index,
&unique_value, if (need_product_rulebook) {
out, DenseTensor tmp_rulebook;
&h_counter, rulebook_len = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
&offsets); x,
kernel_sizes,
const int* counter_ptr = counter_per_kernel.data<int>(); subm_paddings,
const int* offsets_ptr = counter_per_kernel.data<int>(); dilations,
const IntT* rulebook_ptr = rulebook->data<IntT>(); subm_strides,
out_dims,
subm,
&tmp_rulebook,
&counter_per_kernel,
&offsets_per_kernel,
&out_index,
&unique_value,
out,
h_counter_ptr,
h_offsets_ptr);
rulebook_ptr = tmp_rulebook.data<IntT>();
phi::funcs::sparse::SaveToTable(
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
}
// 2. gather // 2. gather
DenseTensorMeta in_features_meta(
x.dtype(), {n, in_channels}, DataLayout::NCHW);
DenseTensorMeta out_features_meta(
x.dtype(), {n, out_channels}, DataLayout::NCHW);
phi::DenseTensor in_features = phi::DenseTensor in_features =
phi::Empty(dev_ctx, std::move(in_features_meta)); phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_features = phi::DenseTensor out_features =
phi::Empty(dev_ctx, std::move(out_features_meta)); phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>(); T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>(); T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero; phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f)); set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
auto config = Gather<T, IntT>(dev_ctx,
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); x.non_zero_elements().data<T>(),
GatherKernel<T, IntT><<<config.block_per_grid.x, rulebook_ptr,
config.thread_per_block.x, rulebook_len,
0, in_channels,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(), in_features_ptr);
rulebook_ptr + n,
in_features_ptr,
n,
in_channels);
// 3. call gemm for every weight // 3. call gemm for every weight
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx); auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
auto* out_values = out->mutable_non_zero_elements(); auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>(); T* out_values_ptr = out_values->data<T>();
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
if (subm) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
unique_value.ResizeAndAllocate(
{static_cast<int>(out->nnz() * kernel_size)});
out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
GroupIndexs<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
kernel_size,
rulebook_ptr + rulebook_len,
out_index_ptr,
unique_value_ptr);
}
const T* kernel_ptr = kernel.data<T>(); const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
if (h_counter[i] <= 0) { if (h_counter_ptr[i] <= 0) {
continue; continue;
} }
// call gemm: (n, in_channels) * (in_channels, out_channels) // call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = h_counter[i]; const int M = h_counter_ptr[i];
const int K = in_channels; const int K = in_channels;
const int N = out_channels; const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N; const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels; T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans, blas.GEMM(CblasNoTrans,
CblasNoTrans, CblasNoTrans,
...@@ -154,40 +190,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -154,40 +190,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
} }
// 4. scatter // 4. scatter
if (subm) { phi::funcs::sparse::ScatterV2<T>(dev_ctx,
set_zero(dev_ctx, out_values, static_cast<T>(0.0f)); out_features_ptr,
config = out_index.data<int>(),
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1); unique_value.data<int>(),
phi::funcs::ScatterCUDAKernel<T, IntT> out->nnz(),
<<<config.block_per_grid, kernel_size,
config.thread_per_block, out_channels,
0, 1,
dev_ctx.stream()>>>(out_features_ptr, out_values_ptr);
rulebook_ptr + 2 * n,
out_values_ptr,
n,
out_channels,
false);
} else {
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, out->nnz() * out_channels, 1);
phi::funcs::sparse::ScatterKernel<T>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_features_ptr,
unique_value.data<int>(),
out_index.data<int>(),
out->nnz(),
n,
out_channels,
out_values_ptr);
}
} }
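The forward kernel above follows a gather-GEMM-scatter pattern: for each of the `kernel_size` offsets, the rulebook selects the participating input rows, a dense GEMM multiplies them by that offset's slice of the weight, and the partial results are scatter-added into the output rows. The following NumPy sketch is illustrative only (it is not the Paddle kernel); the rulebook is assumed here to be given as per-offset pairs of (input_index, output_index) arrays.

```python
# Illustrative gather-GEMM-scatter sparse convolution on the host.
# Assumption: rulebook[k] = (in_idx, out_idx) index arrays for kernel offset k.
import numpy as np

def sparse_conv_gather_gemm_scatter(in_features, kernel, rulebook, n_out):
    # in_features: (nnz_in, C_in), kernel: (kernel_size, C_in, C_out)
    kernel_size, c_in, c_out = kernel.shape
    out_features = np.zeros((n_out, c_out), dtype=in_features.dtype)
    for k in range(kernel_size):
        in_idx, out_idx = rulebook[k]
        if len(in_idx) == 0:                        # mirrors the h_counter_ptr[i] <= 0 skip
            continue
        gathered = in_features[in_idx]              # 1. gather
        partial = gathered @ kernel[k]              # 2. GEMM: (M, C_in) x (C_in, C_out)
        np.add.at(out_features, out_idx, partial)   # 3. scatter-add
    return out_features
```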
/** /**
* x: (N, D, H, W, C) * x: the input SparseCooTensor, shape is (N, D, H, W, C)
* kernel: (D, H, W, C, OC) * kernel: the weight data, shape is (D, H, W, C, OC)
* out: (N, D, H, W, OC) * out: the output SparseCooTensor, shape is (N, D, H, W, OC)
 * rulebook: return the rulebook if the key is not valid, else return nullptr
 * counter: return the counter if the key is not valid, else return nullptr
**/ **/
template <typename T, typename Context> template <typename T, typename Context>
void Conv3dCooKernel(const Context& dev_ctx, void Conv3dCooKernel(const Context& dev_ctx,
...@@ -198,8 +217,10 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -198,8 +217,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
const std::vector<int>& strides, const std::vector<int>& strides,
const int groups, const int groups,
const bool subm, const bool subm,
const std::string& key,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooGPUKernel", ([&] { x.non_zero_indices().dtype(), "Conv3dCooGPUKernel", ([&] {
Conv3dCooGPUKernel<T, data_t>(dev_ctx, Conv3dCooGPUKernel<T, data_t>(dev_ctx,
...@@ -210,8 +231,10 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -210,8 +231,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
strides, strides,
groups, groups,
subm, subm,
key,
out, out,
rulebook); rulebook,
counter);
})); }));
} }
......
...@@ -238,6 +238,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, ...@@ -238,6 +238,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
x_indexs_ptr, x_indexs.numel(), table.data<int>()); x_indexs_ptr, x_indexs.numel(), table.data<int>());
config = config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
const int VecBytes = 16; const int VecBytes = 16;
const int VecSize = VecBytes / sizeof(T); const int VecSize = VecBytes / sizeof(T);
if (stride % VecSize == 0) { if (stride % VecSize == 0) {
......
...@@ -55,6 +55,7 @@ template <typename T, typename IntT = int> ...@@ -55,6 +55,7 @@ template <typename T, typename IntT = int>
void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out, const SparseCooTensor& out,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
...@@ -63,23 +64,9 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -63,23 +64,9 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
const int in_channels = x.dims()[4]; const int in_channels = x.dims()[4];
int rulebook_len = rulebook.dims()[1]; int rulebook_len = rulebook.dims()[1];
const IntT* rulebook_ptr = rulebook.data<IntT>(); const IntT* rulebook_ptr = rulebook.data<IntT>();
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0), std::vector<int> offsets(kernel_size + 1);
h_counter(rulebook_len, 0); const int* counter_ptr = counter.data<int>();
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
rulebook_ptr,
rulebook_len * sizeof(IntT),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
for (int i = 0; i < rulebook_len; i++) {
counter[h_counter[i]] += 1;
}
phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);
const T* in_features_ptr = x.non_zero_elements().data<T>(); const T* in_features_ptr = x.non_zero_elements().data<T>();
const T* out_features_ptr = out.non_zero_elements().data<T>(); const T* out_features_ptr = out.non_zero_elements().data<T>();
...@@ -99,12 +86,12 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -99,12 +86,12 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
&x_grad_indices); &x_grad_indices);
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) { if (counter_ptr[i] <= 0) {
continue; continue;
} }
auto config = phi::backends::gpu::GetGpuLaunchConfig1D( auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, counter[i] * in_channels, 1); dev_ctx, counter_ptr[i] * in_channels, 1);
MaxPoolGradCudaKernel<T, IntT> MaxPoolGradCudaKernel<T, IntT>
<<<config.block_per_grid.x, <<<config.block_per_grid.x,
config.thread_per_block.x, config.thread_per_block.x,
...@@ -112,8 +99,8 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -112,8 +99,8 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
dev_ctx.stream()>>>(in_features_ptr, dev_ctx.stream()>>>(in_features_ptr,
out_features_ptr, out_features_ptr,
out_grad_ptr, out_grad_ptr,
rulebook_ptr + offsets[i] + rulebook_len, rulebook_ptr + offsets[i],
counter[i], counter_ptr[i],
rulebook_len, rulebook_len,
in_channels, in_channels,
x_grad_ptr); x_grad_ptr);
...@@ -124,6 +111,7 @@ template <typename T, typename Context> ...@@ -124,6 +111,7 @@ template <typename T, typename Context>
void MaxPoolCooGradKernel(const Context& dev_ctx, void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out, const SparseCooTensor& out,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
...@@ -131,7 +119,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx, ...@@ -131,7 +119,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGradGPUKernel", ([&] { x.non_zero_indices().dtype(), "MaxPoolCooGradGPUKernel", ([&] {
MaxPoolCooGradGPUKernel<T, data_t>( MaxPoolCooGradGPUKernel<T, data_t>(
dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
})); }));
} }
......
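In the backward max-pool kernel above, the per-kernel-offset `counter` is turned into segment offsets into the rulebook by an exclusive prefix sum (the `PrefixSum` call). A tiny sketch of that bookkeeping, with illustrative values:

```python
# Exclusive prefix sum over the per-offset counters; the values are illustrative.
import numpy as np

counter = np.array([3, 0, 5, 2])                     # rulebook entries per kernel offset
offsets = np.concatenate(([0], np.cumsum(counter)))  # -> [0, 3, 3, 8, 10]
# Segment i of the rulebook covers columns offsets[i] .. offsets[i] + counter[i] - 1,
# which is why the launch above indexes rulebook_ptr + offsets[i] with length counter_ptr[i].
```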
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" #include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
...@@ -55,7 +55,8 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, ...@@ -55,7 +55,8 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
const std::vector<int>& real_kernel_sizes = const std::vector<int>& real_kernel_sizes =
...@@ -65,7 +66,7 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, ...@@ -65,7 +66,7 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims); x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
const int in_channels = real_kernel_sizes[3]; const int in_channels = real_kernel_sizes[3];
std::vector<int> offsets(kernel_size + 1), counter(kernel_size); std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
DenseTensorMeta counter_meta( DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW); DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
...@@ -89,13 +90,16 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, ...@@ -89,13 +90,16 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
&out_index, &out_index,
&unique_value, &unique_value,
out, out,
&counter, h_counter.data(),
&offsets); offsets.data());
const IntT* rulebook_ptr = rulebook->data<IntT>(); const IntT* rulebook_ptr = rulebook->data<IntT>();
T* out_features_ptr = out->mutable_non_zero_elements()->data<T>(); T* out_features_ptr = out->mutable_non_zero_elements()->data<T>();
const T* in_features_ptr = x.non_zero_elements().data<T>(); const T* in_features_ptr = x.non_zero_elements().data<T>();
counter->Resize({kernel_size});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, h_counter.data(), h_counter.size() * sizeof(int));
// 2. max pool // 2. max pool
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
thrust::fill(thrust::hip::par.on(dev_ctx.stream()), thrust::fill(thrust::hip::par.on(dev_ctx.stream()),
...@@ -107,22 +111,21 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, ...@@ -107,22 +111,21 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
static_cast<T>(0)); static_cast<T>(0));
// TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster // TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster
for (int i = 0; i < kernel_size; i++) { for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) { if (h_counter[i] <= 0) {
continue; continue;
} }
auto config = phi::backends::gpu::GetGpuLaunchConfig1D( auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, counter[i] * in_channels, 1); dev_ctx, h_counter[i] * in_channels, 1);
MaxPoolCudaKernel<T, IntT> MaxPoolCudaKernel<T, IntT><<<config.block_per_grid.x,
<<<config.block_per_grid.x, config.thread_per_block.x,
config.thread_per_block.x, 0,
0, dev_ctx.stream()>>>(in_features_ptr,
dev_ctx.stream()>>>(in_features_ptr, rulebook_ptr + offsets[i],
rulebook_ptr + offsets[i] + rulebook_len, h_counter[i],
counter[i], rulebook_len,
rulebook_len, in_channels,
in_channels, out_features_ptr);
out_features_ptr);
} }
} }
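The forward max-pool loop above processes one kernel offset at a time: the rulebook segment for offset i maps `h_counter[i]` input rows to output rows, and the CUDA kernel folds them into the output with an element-wise max. A NumPy sketch of one such segment (names are illustrative, not Paddle internals); it assumes the output buffer was zero-initialized as in the code above, so it is only exact for non-negative features.

```python
# Rulebook-driven max pooling for a single kernel offset.
import numpy as np

def maxpool_segment(in_features, out_features, in_idx, out_idx):
    # in_idx / out_idx: the rulebook segment for one kernel offset
    np.maximum.at(out_features, out_idx, in_features[in_idx])  # running max per output row
    return out_features
```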
...@@ -134,7 +137,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, ...@@ -134,7 +137,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGPUKernel", ([&] { x.non_zero_indices().dtype(), "MaxPoolCooGPUKernel", ([&] {
MaxPoolCooGPUKernel<T, data_t>(dev_ctx, MaxPoolCooGPUKernel<T, data_t>(dev_ctx,
...@@ -144,7 +148,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, ...@@ -144,7 +148,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
dilations, dilations,
strides, strides,
out, out,
rulebook); rulebook,
counter);
})); }));
} }
......
...@@ -25,6 +25,7 @@ template <typename T, typename Context> ...@@ -25,6 +25,7 @@ template <typename T, typename Context>
void MaxPoolCooGradKernel(const Context& dev_ctx, void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out, const SparseCooTensor& out,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
...@@ -34,12 +35,13 @@ template <typename T, typename Context> ...@@ -34,12 +35,13 @@ template <typename T, typename Context>
SparseCooTensor MaxPoolCooGrad(const Context& dev_ctx, SparseCooTensor MaxPoolCooGrad(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& rulebook, const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out, const SparseCooTensor& out,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes) { const std::vector<int>& kernel_sizes) {
SparseCooTensor x_grad; SparseCooTensor x_grad;
MaxPoolCooGradKernel<T, Context>( MaxPoolCooGradKernel<T, Context>(
dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad); dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, &x_grad);
return x_grad; return x_grad;
} }
......
...@@ -29,7 +29,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, ...@@ -29,7 +29,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook); DenseTensor* rulebook,
DenseTensor* counter);
template <typename T, typename Context> template <typename T, typename Context>
SparseCooTensor MaxPoolCoo(const Context& dev_ctx, SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
...@@ -38,10 +39,18 @@ SparseCooTensor MaxPoolCoo(const Context& dev_ctx, ...@@ -38,10 +39,18 @@ SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
DenseTensor* rulebook) { DenseTensor* rulebook,
DenseTensor* counter) {
SparseCooTensor coo; SparseCooTensor coo;
MaxPoolCooKernel<T, Context>( MaxPoolCooKernel<T, Context>(dev_ctx,
dev_ctx, x, kernel_sizes, paddings, dilations, strides, &coo, rulebook); x,
kernel_sizes,
paddings,
dilations,
strides,
&coo,
rulebook,
counter);
return coo; return coo;
} }
......
...@@ -76,8 +76,8 @@ void TestConv3dBase(const std::vector<int>& indices, ...@@ -76,8 +76,8 @@ void TestConv3dBase(const std::vector<int>& indices,
kernel.size() * sizeof(T)); kernel.size() * sizeof(T));
if (!std::is_same<T, phi::dtype::float16>::value) { if (!std::is_same<T, phi::dtype::float16>::value) {
auto tensor_out = paddle::experimental::sparse::conv3d( auto tensor_out = paddle::experimental::sparse::conv3d_coo(
x, weight, paddings, dilations, strides, 1, false); x, weight, paddings, dilations, strides, 1, false, "Conv3d");
auto out = auto out =
std::dynamic_pointer_cast<phi::SparseCooTensor>(tensor_out.impl()); std::dynamic_pointer_cast<phi::SparseCooTensor>(tensor_out.impl());
......
...@@ -112,8 +112,7 @@ void TestConv3dBase(const std::vector<IntT>& indices, ...@@ -112,8 +112,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
}; };
if (!std::is_same<T, phi::dtype::float16>::value) { if (!std::is_same<T, phi::dtype::float16>::value) {
DenseTensor rulebook = phi::Empty( DenseTensor rulebook, counter;
dev_ctx_cpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW));
SparseCooTensor out = sparse::Conv3dCoo<T>(dev_ctx_cpu, SparseCooTensor out = sparse::Conv3dCoo<T>(dev_ctx_cpu,
x_tensor, x_tensor,
kernel_tensor, kernel_tensor,
...@@ -122,7 +121,9 @@ void TestConv3dBase(const std::vector<IntT>& indices, ...@@ -122,7 +121,9 @@ void TestConv3dBase(const std::vector<IntT>& indices,
strides, strides,
1, 1,
subm, subm,
&rulebook); "Conv3d",
&rulebook,
&counter);
ASSERT_EQ(correct_out_dims.size(), out.dims().size()); ASSERT_EQ(correct_out_dims.size(), out.dims().size());
for (int i = 0; i < correct_out_dims.size(); i++) { for (int i = 0; i < correct_out_dims.size(); i++) {
...@@ -142,13 +143,16 @@ void TestConv3dBase(const std::vector<IntT>& indices, ...@@ -142,13 +143,16 @@ void TestConv3dBase(const std::vector<IntT>& indices,
sparse::Conv3dCooGrad<T>(dev_ctx_cpu, sparse::Conv3dCooGrad<T>(dev_ctx_cpu,
x_tensor, x_tensor,
kernel_tensor, kernel_tensor,
out,
rulebook, rulebook,
counter,
out, out,
paddings, paddings,
dilations, dilations,
strides, strides,
1, 1,
subm); subm,
"Conv3d");
f_verify(std::get<0>(grads).non_zero_elements().data<T>(), features_grad); f_verify(std::get<0>(grads).non_zero_elements().data<T>(), features_grad);
f_verify(std::get<1>(grads).data<T>(), kernel_grad); f_verify(std::get<1>(grads).data<T>(), kernel_grad);
} }
...@@ -196,8 +200,7 @@ void TestConv3dBase(const std::vector<IntT>& indices, ...@@ -196,8 +200,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
phi::Copy( phi::Copy(
dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor); dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor);
DenseTensor d_rulebook = phi::Empty( DenseTensor d_rulebook, d_counter;
dev_ctx_gpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW));
SparseCooTensor d_out = sparse::Conv3dCoo<T>(dev_ctx_gpu, SparseCooTensor d_out = sparse::Conv3dCoo<T>(dev_ctx_gpu,
d_x_tensor, d_x_tensor,
d_kernel_tensor, d_kernel_tensor,
...@@ -206,8 +209,9 @@ void TestConv3dBase(const std::vector<IntT>& indices, ...@@ -206,8 +209,9 @@ void TestConv3dBase(const std::vector<IntT>& indices,
strides, strides,
1, 1,
subm, subm,
&d_rulebook); "Conv3d",
&d_rulebook,
&d_counter);
SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out); SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
...@@ -245,13 +249,16 @@ void TestConv3dBase(const std::vector<IntT>& indices, ...@@ -245,13 +249,16 @@ void TestConv3dBase(const std::vector<IntT>& indices,
sparse::Conv3dCooGrad<T>(dev_ctx_gpu, sparse::Conv3dCooGrad<T>(dev_ctx_gpu,
d_x_tensor, d_x_tensor,
d_kernel_tensor, d_kernel_tensor,
d_out,
d_rulebook, d_rulebook,
d_counter,
d_out, d_out,
paddings, paddings,
dilations, dilations,
strides, strides,
1, 1,
subm); subm,
"Conv3d");
DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements(); DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements();
DenseTensor d_kernel_grad = std::get<1>(grads); DenseTensor d_kernel_grad = std::get<1>(grads);
DenseTensor h_features_grad = DenseTensor h_features_grad =
......
...@@ -90,14 +90,15 @@ void TestMaxPoolBase(const std::vector<IntT>& indices, ...@@ -90,14 +90,15 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
}; };
if (!std::is_same<T, phi::dtype::float16>::value) { if (!std::is_same<T, phi::dtype::float16>::value) {
DenseTensor rulebook; DenseTensor rulebook, counter;
SparseCooTensor out = sparse::MaxPoolCoo<T>(dev_ctx_cpu, SparseCooTensor out = sparse::MaxPoolCoo<T>(dev_ctx_cpu,
x_tensor, x_tensor,
kernel_sizes, kernel_sizes,
paddings, paddings,
dilations, dilations,
strides, strides,
&rulebook); &rulebook,
&counter);
ASSERT_EQ(correct_out_dims.size(), out.dims().size()); ASSERT_EQ(correct_out_dims.size(), out.dims().size());
for (int i = 0; i < correct_out_dims.size(); i++) { for (int i = 0; i < correct_out_dims.size(); i++) {
...@@ -114,7 +115,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices, ...@@ -114,7 +115,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
if (backward) { if (backward) {
SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>( SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes); dev_ctx_cpu, x_tensor, rulebook, counter, out, out, kernel_sizes);
f_verify(x_grad.non_zero_elements().data<T>(), features_grad); f_verify(x_grad.non_zero_elements().data<T>(), features_grad);
} }
} }
...@@ -150,14 +151,16 @@ void TestMaxPoolBase(const std::vector<IntT>& indices, ...@@ -150,14 +151,16 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims); SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims);
DenseTensor d_rulebook; DenseTensor d_rulebook, d_counter;
SparseCooTensor d_out = sparse::MaxPoolCoo<T>(dev_ctx_gpu, SparseCooTensor d_out = sparse::MaxPoolCoo<T>(dev_ctx_gpu,
d_x_tensor, d_x_tensor,
kernel_sizes, kernel_sizes,
paddings, paddings,
dilations, dilations,
strides, strides,
&d_rulebook); &d_rulebook,
&d_counter);
SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out); SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
...@@ -191,8 +194,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices, ...@@ -191,8 +194,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
f_verify(h_features_tensor.data<T>(), correct_out_features); f_verify(h_features_tensor.data<T>(), correct_out_features);
if (backward) { if (backward) {
SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>( SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(dev_ctx_gpu,
dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes); d_x_tensor,
d_rulebook,
d_counter,
d_out,
d_out,
kernel_sizes);
DenseTensor h_features_grad = DenseTensor h_features_grad =
phi::EmptyLike<T>(dev_ctx_cpu, x_grad.non_zero_elements()); phi::EmptyLike<T>(dev_ctx_cpu, x_grad.non_zero_elements());
phi::Copy(dev_ctx_gpu, phi::Copy(dev_ctx_gpu,
......
...@@ -67,7 +67,7 @@ class TestSparseConv(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestSparseConv(unittest.TestCase):
indices, values, dense_shape, stop_gradient=True) indices, values, dense_shape, stop_gradient=True)
weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32') weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')
y = paddle.incubate.sparse.nn.functional.subm_conv3d( y = paddle.incubate.sparse.nn.functional.subm_conv3d(
sparse_x, weight) sparse_x, weight, key='subm_conv')
assert np.array_equal(sparse_x.indices().numpy(), assert np.array_equal(sparse_x.indices().numpy(),
y.indices().numpy()) y.indices().numpy())
...@@ -91,7 +91,7 @@ class TestSparseConv(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestSparseConv(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
#Currently, only support data_format='NDHWC' #Currently, only support data_format='NDHWC'
conv3d = paddle.incubate.sparse.nn.SubmConv3D( conv3d = paddle.incubate.sparse.nn.SubmConv3D(
1, 1, (1, 3, 3), data_format='NCDHW') 1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')
def test_SubmConv3D(self): def test_SubmConv3D(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -105,7 +105,7 @@ class TestSparseConv(unittest.TestCase): ...@@ -105,7 +105,7 @@ class TestSparseConv(unittest.TestCase):
indices, values, dense_shape, False) indices, values, dense_shape, False)
subm_conv3d = paddle.incubate.sparse.nn.SubmConv3D( subm_conv3d = paddle.incubate.sparse.nn.SubmConv3D(
1, 1, (1, 3, 3), data_format='NDHWC') 1, 1, (1, 3, 3), data_format='NDHWC', key='subm_conv')
# test extra_repr # test extra_repr
print(subm_conv3d.extra_repr()) print(subm_conv3d.extra_repr())
...@@ -117,7 +117,7 @@ class TestSparseConv(unittest.TestCase): ...@@ -117,7 +117,7 @@ class TestSparseConv(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
#Currently, only support data_format='NDHWC' #Currently, only support data_format='NDHWC'
conv3d = paddle.incubate.sparse.nn.SubmConv3D( conv3d = paddle.incubate.sparse.nn.SubmConv3D(
1, 1, (1, 3, 3), data_format='NCDHW') 1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')
def test_Conv3D_bias(self): def test_Conv3D_bias(self):
with _test_eager_guard(): with _test_eager_guard():
......
...@@ -29,6 +29,7 @@ def _conv3d(x, ...@@ -29,6 +29,7 @@ def _conv3d(x,
dilation=1, dilation=1,
groups=1, groups=1,
subm=False, subm=False,
key=None,
data_format="NDHWC", data_format="NDHWC",
name=None): name=None):
assert in_dynamic_mode(), "Currently, only support dynamic mode" assert in_dynamic_mode(), "Currently, only support dynamic mode"
...@@ -62,8 +63,9 @@ def _conv3d(x, ...@@ -62,8 +63,9 @@ def _conv3d(x,
dilation = convert_to_list(dilation, dims, 'dilation') dilation = convert_to_list(dilation, dims, 'dilation')
op_type = "conv3d" op_type = "conv3d"
pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation, pre_bias = _C_ops.final_state_sparse_conv3d_coo(
stride, groups, subm) x, weight, padding, dilation, stride, groups, subm,
key if key is not None else "")
if bias is not None: if bias is not None:
values = pre_bias.values() values = pre_bias.values()
add_bias = elementwise_add(values, bias, axis=1) add_bias = elementwise_add(values, bias, axis=1)
...@@ -186,7 +188,7 @@ def conv3d(x, ...@@ -186,7 +188,7 @@ def conv3d(x,
# (1, 1, 1, 2, 1) # (1, 1, 1, 2, 1)
""" """
return _conv3d(x, weight, bias, stride, padding, dilation, groups, False, return _conv3d(x, weight, bias, stride, padding, dilation, groups, False,
data_format, name) None, data_format, name)
def subm_conv3d(x, def subm_conv3d(x,
...@@ -197,6 +199,7 @@ def subm_conv3d(x, ...@@ -197,6 +199,7 @@ def subm_conv3d(x,
dilation=1, dilation=1,
groups=1, groups=1,
data_format="NDHWC", data_format="NDHWC",
key=None,
name=None): name=None):
r""" r"""
...@@ -274,6 +277,10 @@ def subm_conv3d(x, ...@@ -274,6 +277,10 @@ def subm_conv3d(x,
will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`. will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`.
The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of: The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of:
`[batch_size, input_depth, input_height, input_width, input_channels]`. `[batch_size, input_depth, input_height, input_width, input_channels]`.
key(str, optional): the key is used to save and reuse the rulebook; calls that
pass the same key share one rulebook instead of rebuilding it. For the
definition and role of the rulebook, refer to
https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf.
The default value is None.
name(str|None): For detailed information, please refer name(str|None): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and to :ref:`api_guide_Name`. Usually name is no need to set and
None by default. None by default.
...@@ -301,4 +308,4 @@ def subm_conv3d(x, ...@@ -301,4 +308,4 @@ def subm_conv3d(x,
#(1, 1, 3, 4, 1) #(1, 1, 3, 4, 1)
""" """
return _conv3d(x, weight, bias, stride, padding, dilation, groups, True, return _conv3d(x, weight, bias, stride, padding, dilation, groups, True,
data_format, name) key, data_format, name)
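A hedged usage sketch of the functional API with a rulebook key, modeled on the unit test above; the `sparse_coo_tensor` constructor name and the tensor shapes are taken from that test and may differ in other Paddle versions. Two submanifold convolutions that pass the same `key` can share one rulebook instead of rebuilding it.

```python
import paddle

# NDHWC sparse input with 4 non-zero voxels and 1 channel, as in the tests above.
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1.0], [2.0], [3.0], [4.0]]
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
    paddle.to_tensor(indices, dtype='int32'),
    paddle.to_tensor(values, dtype='float32'),
    shape=[1, 1, 3, 4, 1])

weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')

# Both calls use key='subm_conv', so the second call may reuse the cached rulebook.
y1 = paddle.incubate.sparse.nn.functional.subm_conv3d(sparse_x, weight, key='subm_conv')
y2 = paddle.incubate.sparse.nn.functional.subm_conv3d(y1, weight, key='subm_conv')
```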
...@@ -33,6 +33,7 @@ class _Conv3D(Layer): ...@@ -33,6 +33,7 @@ class _Conv3D(Layer):
dilation=1, dilation=1,
groups=1, groups=1,
subm=False, subm=False,
key=None,
padding_mode='zeros', padding_mode='zeros',
weight_attr=None, weight_attr=None,
bias_attr=None, bias_attr=None,
...@@ -46,6 +47,7 @@ class _Conv3D(Layer): ...@@ -46,6 +47,7 @@ class _Conv3D(Layer):
self._out_channels = out_channels self._out_channels = out_channels
self._data_format = data_format self._data_format = data_format
self._subm = subm self._subm = subm
self._key = key
assert padding_mode == 'zeros', "Currently, only support padding_mode='zeros'" assert padding_mode == 'zeros', "Currently, only support padding_mode='zeros'"
assert groups == 1, "Currently, only support groups=1" assert groups == 1, "Currently, only support groups=1"
...@@ -95,6 +97,7 @@ class _Conv3D(Layer): ...@@ -95,6 +97,7 @@ class _Conv3D(Layer):
dilation=self._dilation, dilation=self._dilation,
groups=self._groups, groups=self._groups,
subm=self._subm, subm=self._subm,
key=self._key,
data_format=self._data_format) data_format=self._data_format)
return out return out
...@@ -240,6 +243,7 @@ class Conv3D(_Conv3D): ...@@ -240,6 +243,7 @@ class Conv3D(_Conv3D):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
subm=False, subm=False,
key=None,
padding_mode=padding_mode, padding_mode=padding_mode,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr, bias_attr=bias_attr,
...@@ -293,6 +297,10 @@ class SubmConv3D(_Conv3D): ...@@ -293,6 +297,10 @@ class SubmConv3D(_Conv3D):
of the input channels, while the second half of the filters is only of the input channels, while the second half of the filters is only
connected to the second half of the input channels. The default value is 1. connected to the second half of the input channels. The default value is 1.
padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``. padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``.
key(str, optional): the key is used to save and reuse the rulebook; layers that
pass the same key share one rulebook instead of rebuilding it. For the
definition and role of the rulebook, refer to
https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf.
The default value is None.
weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights
of conv3d. If it is set to None or one attribute of ParamAttr, conv3d of conv3d. If it is set to None or one attribute of ParamAttr, conv3d
will create ParamAttr as param_attr. If it is set to None, the parameter will create ParamAttr as param_attr. If it is set to None, the parameter
...@@ -361,6 +369,7 @@ class SubmConv3D(_Conv3D): ...@@ -361,6 +369,7 @@ class SubmConv3D(_Conv3D):
dilation=1, dilation=1,
groups=1, groups=1,
padding_mode='zeros', padding_mode='zeros',
key=None,
weight_attr=None, weight_attr=None,
bias_attr=None, bias_attr=None,
data_format="NDHWC"): data_format="NDHWC"):
...@@ -372,6 +381,7 @@ class SubmConv3D(_Conv3D): ...@@ -372,6 +381,7 @@ class SubmConv3D(_Conv3D):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
subm=True, subm=True,
key=key,
padding_mode=padding_mode, padding_mode=padding_mode,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr, bias_attr=bias_attr,
......
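The same key mechanism is available at the layer level. A hedged sketch, assuming the SubmConv3D signature shown in the diff above; `sparse_x` stands for an NDHWC SparseCooTensor such as the one built in the functional example earlier.

```python
import paddle

conv1 = paddle.incubate.sparse.nn.SubmConv3D(
    1, 1, (1, 3, 3), data_format='NDHWC', key='backbone')
conv2 = paddle.incubate.sparse.nn.SubmConv3D(
    1, 1, (1, 3, 3), data_format='NDHWC', key='backbone')

# Layers that share key='backbone' can reuse one rulebook across calls.
# sparse_x: built as in the functional example above.
y = conv2(conv1(sparse_x))
```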