未验证 提交 9841b308 编写于 作者: Z zhangkaihuo 提交者: GitHub

Optimize sparse convolution (#43576)

上级 22342d51
......@@ -80,14 +80,14 @@
data_type : x
backward : cast_grad
- api : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
output : Tensor(out), Tensor(rulebook)
- api : conv3d_coo
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(out), Tensor(rulebook), Tensor(counter)
kernel :
func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense}
func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense}
layout : x
intermediate : rulebook
backward : conv3d_grad
intermediate: rulebook, counter
backward : conv3d_coo_grad
- api : coo_to_dense
args : (Tensor x)
......@@ -352,11 +352,11 @@
- api: maxpool
args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
output : Tensor(out), Tensor(rulebook)
output : Tensor(out), Tensor(rulebook), Tensor(counter)
kernel :
func : maxpool_coo{sparse_coo -> sparse_coo, dense}
func : maxpool_coo{sparse_coo -> sparse_coo, dense, dense}
layout : x
intermediate : rulebook
intermediate : rulebook, counter
backward : maxpool_grad
- api: mv
......
......@@ -81,12 +81,12 @@
cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
data_type : out_grad
- backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
- backward_api : conv3d_coo_grad
forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(x_grad), Tensor(kernel_grad)
kernel :
func : conv3d_coo_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
- backward_api : coo_to_dense_grad
forward : coo_to_dense(Tensor x) -> Tensor(out)
......@@ -164,11 +164,11 @@
matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}
- backward_api : maxpool_grad
forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes)
output : Tensor(x_grad)
kernel :
func : maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
func : maxpool_coo_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> sparse_coo}
- backward_api : multiply_grad
forward : multiply(Tensor x, Tensor y) -> Tensor(out)
......
......@@ -156,6 +156,48 @@ class SparseCooTensor : public TensorBase,
/// \brief get the dnese dim
int32_t dense_dim() const;
/// \brief query table according to key
const std::pair<DenseTensor, DenseTensor>* IndicesPairs(
const std::string& key) const {
if (indices_dict_ == nullptr) {
return nullptr;
}
const auto& iter = indices_dict_->find(key);
if (iter == indices_dict_->end()) {
return nullptr;
}
return &iter->second;
}
/// \brief save (key, indices_pairs)
void SaveIndicesPairs(
const std::string& key,
const std::pair<DenseTensor, DenseTensor>& indices_pairs) {
if (indices_dict_ == nullptr) {
indices_dict_ = std::make_shared<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>();
}
auto ret = indices_dict_->insert({key, indices_pairs});
if (ret.second == false) {
ret.first->second = indices_pairs;
}
}
/// \brief get indices_dict_
const std::shared_ptr<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
GetIndicesDict() const {
return indices_dict_;
}
/// \brief set indices_dict_
void SetIndicesDict(
const std::shared_ptr<
std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
indices_dict) {
indices_dict_ = indices_dict;
}
private:
// save the indices of non zero elements in original dense tensor
DenseTensor non_zero_indices_;
......@@ -165,6 +207,14 @@ class SparseCooTensor : public TensorBase,
bool coalesced_ = false;
// save the number of non zero elements in each batch
DDim dims_;
// for submanifold conv
// SubmConv will generate a rulebook and a counter, which can be
// reused by different SubmConv.
// refer to sparse/gpu/convolution_kernel.cu.
std::shared_ptr<std::map<std::string, std::pair<DenseTensor, DenseTensor>>>
indices_dict_ = nullptr;
/* --------------------------- */
/* example: non zero element is scalar */
/* --------------------------- */
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
......@@ -188,6 +189,88 @@ inline void PrefixSum(const T* counter, T* offsets, const int n) {
offsets[n] = offset;
}
template <typename IntT>
inline const IntT* GetRulebookPtr(const SparseCooTensor& coo,
const DenseTensor& rulebook,
const std::string& key,
int* rulebook_len) {
if (!key.empty()) {
const auto* indices_pairs = coo.IndicesPairs(key);
if (indices_pairs != nullptr) {
const DenseTensor& tmp_rulebook = indices_pairs->first;
*rulebook_len = tmp_rulebook.dims()[1];
return tmp_rulebook.data<IntT>();
}
}
*rulebook_len = rulebook.dims()[1];
return rulebook.data<IntT>();
}
inline const int* GetCounterPtr(const SparseCooTensor& coo,
const DenseTensor& counter,
const std::string& key) {
if (!key.empty()) {
const auto* indices_pairs = coo.IndicesPairs(key);
if (indices_pairs != nullptr) {
return indices_pairs->second.data<int>();
}
}
return counter.data<int>();
}
template <typename T, typename IntT, typename Context>
inline const IntT* PrepareSubm(const Context& dev_ctx,
const SparseCooTensor& x,
const std::string& key,
const DDim& out_dims,
SparseCooTensor* out,
int* counter,
int* offsets,
int* rulebook_len,
bool* need_product_rulebook) {
const auto* indices_pairs = x.IndicesPairs(key);
if (indices_pairs != nullptr) {
*need_product_rulebook = false;
const DenseTensor& rulebook = indices_pairs->first;
const int counter_size = indices_pairs->second.numel();
memcpy(
counter, indices_pairs->second.data<int>(), counter_size * sizeof(int));
out->SetIndicesDict(x.GetIndicesDict());
*rulebook_len = rulebook.dims()[1];
DenseTensor out_indices =
phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
phi::Copy(
dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices);
out->SetMember(out_indices, out_values, out_dims, false);
PrefixSum<int>(counter, offsets, counter_size);
return rulebook.data<IntT>();
}
return nullptr;
}
template <typename Context>
inline void SaveToTable(const Context& dev_ctx,
const SparseCooTensor& x,
const std::string& key,
const DenseTensor& in_rulebook,
const DenseTensor& h_counter,
SparseCooTensor* out,
DenseTensor* out_rulebook,
DenseTensor* counter) {
out->SetIndicesDict(x.GetIndicesDict());
if (!key.empty()) {
out->SaveIndicesPairs(key, std::make_pair(in_rulebook, h_counter));
} else {
*out_rulebook = in_rulebook;
counter->Resize({h_counter.numel()});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, h_counter.data<int>(), h_counter.numel() * sizeof(int));
}
}
} // namespace sparse
} // namespace funcs
} // namespace phi
......@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#define VecBytes 16
namespace phi {
namespace funcs {
......@@ -28,33 +33,126 @@ namespace sparse {
* channels: the output channel size
* out: the outputs
**/
template <typename T>
template <typename T, int VecSize>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out,
const bool subm = false) {
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
int channels_i = i - indices_i * channels;
const int vec_channels = channels / VecSize;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
for (int i = tid; i < non_zero_num * vec_channels;
i += gridDim.x * blockDim.x) {
int indices_i = i / vec_channels;
int channels_i = i - indices_i * vec_channels;
int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
if (subm) {
sum = out[indices_i * channels + channels_i];
}
StoreT sums = {static_cast<T>(0)};
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
LoadT vec_in;
phi::Load<T, VecSize>(
input + out_feature_i * channels + channels_i * VecSize, &vec_in);
#pragma unroll
for (int k = 0; k < VecSize; k++) {
sums[k] += vec_in[k];
}
}
phi::Store<T, VecSize>(sums,
out + indices_i * channels + channels_i * VecSize);
}
}
// scatter's index has been grouped in advance
// index_counts record the count of each group
// index_groups save the index of each group
template <typename T, int VecSize>
__global__ void ScatterKernelV2(const T* input,
const int* index_counts,
const int* index_groups,
const int non_zero_num,
const int kernel_size,
const int channels,
const int buffer_counts,
T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int vec_channels = channels / VecSize;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
for (int i = tid; i < non_zero_num * vec_channels;
i += gridDim.x * blockDim.x) {
int indices_i = i / vec_channels;
int channels_i = i - indices_i * vec_channels;
StoreT sums = {static_cast<T>(0)};
phi::Load<T, VecSize>(out + indices_i * channels + channels_i * VecSize,
&sums);
for (int it = 0; it < buffer_counts; it++) {
int len = index_counts[indices_i + it * non_zero_num];
const int group_offset = it * kernel_size * non_zero_num;
for (int j = 0; j < len; j++) {
const int out_feature_i =
index_groups[indices_i * kernel_size + j + group_offset];
LoadT vec_in;
phi::Load<T, VecSize>(
input + out_feature_i * channels + channels_i * VecSize, &vec_in);
#pragma unroll
for (int k = 0; k < VecSize; k++) {
sums[k] += vec_in[k];
}
}
}
phi::Store<T, VecSize>(sums,
out + indices_i * channels + channels_i * VecSize);
}
out[indices_i * channels + channels_i] = sum;
}
template <typename T>
void ScatterV2(const GPUContext& dev_ctx,
const T* input,
const int* index_counts,
const int* index_groups,
const int non_zero_num,
const int kernel_size,
const int channels,
const int buffer_counts,
T* output) {
const int VecSize = VecBytes / sizeof(T);
if (channels % VecSize == 0) {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, non_zero_num * channels / VecSize, 1);
ScatterKernelV2<T, VecSize><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(input,
index_counts,
index_groups,
non_zero_num,
kernel_size,
channels,
buffer_counts,
output);
} else {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, non_zero_num * channels, 1);
ScatterKernelV2<T, 1><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(input,
index_counts,
index_groups,
non_zero_num,
kernel_size,
channels,
buffer_counts,
output);
}
}
......
......@@ -25,13 +25,16 @@ template <typename T, typename Context>
void Conv3dCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad);
......@@ -40,13 +43,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad(
const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm) {
const bool subm,
const std::string& key) {
SparseCooTensor x_grad;
DenseTensor kernel_grad;
......@@ -54,13 +60,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad(
Conv3dCooGradKernel<T, Context>(dev_ctx,
x,
kernel,
out,
rulebook,
counter,
out_grad,
paddings,
dilations,
strides,
groups,
subm,
key,
&x_grad,
&kernel_grad);
return std::make_tuple(x_grad, kernel_grad);
......
......@@ -31,8 +31,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* out,
DenseTensor* rulebook);
DenseTensor* rulebook,
DenseTensor* counter);
template <typename T, typename Context>
SparseCooTensor Conv3dCoo(const Context& dev_ctx,
......@@ -43,7 +45,9 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* rulebook) {
const std::string& key,
DenseTensor* rulebook,
DenseTensor* counter) {
SparseCooTensor coo;
Conv3dCooKernel<T, Context>(dev_ctx,
x,
......@@ -53,8 +57,10 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
strides,
groups,
subm,
key,
&coo,
rulebook);
rulebook,
counter);
return coo;
}
......
......@@ -41,13 +41,12 @@ void ProductRuleBook(const Context& dev_ctx,
const DDim& out_dims,
const bool subm,
DenseTensor* rulebook,
DenseTensor* counter_per_kernel) {
int* counter_per_kernel) {
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const IntT* indices_ptr = non_zero_indices.data<IntT>();
int* counter_ptr = counter_per_kernel->data<int>();
int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
memset(counter_ptr, 0, kernel_size * sizeof(int));
memset(counter_per_kernel, 0, kernel_size * sizeof(int));
int rulebook_len = 0;
// calc the rulebook_len
......@@ -107,7 +106,7 @@ void ProductRuleBook(const Context& dev_ctx,
}
if (rulebook_ptr == nullptr) {
counter_ptr[kernel_index - 1] += 1;
counter_per_kernel[kernel_index - 1] += 1;
++rulebook_len;
} else {
rulebook_ptr[rulebook_index] = kernel_index - 1;
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/conv.h"
namespace phi {
namespace sparse {
......@@ -34,22 +34,27 @@ template <typename T, typename IntT = int>
void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
const IntT* rulebook_ptr = rulebook.data<IntT>();
const int rulebook_len = rulebook.dims()[1];
int rulebook_len = 0;
const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
out, rulebook, key, &rulebook_len);
const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key);
DenseTensorMeta in_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
......@@ -86,16 +91,14 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
&x_grad_indices);
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0);
for (int i = 0; i < rulebook_len; i++) {
counter[rulebook_ptr[i]] += 1;
}
IntT offset = 0, max_count = 0;
std::vector<IntT> offsets(kernel_size + 1);
IntT offset = 0;
int max_count = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter[i];
offset += counter_ptr[i];
if (i < half_kernel_size) {
max_count = std::max(max_count, counter[i]);
max_count = std::max(max_count, counter_ptr[i]);
}
}
offsets[kernel_size] = offset;
......@@ -129,11 +132,11 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0 || (subm && i == half_kernel_size)) {
if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
continue;
}
const int M = counter[i];
const int M = counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
......@@ -171,7 +174,7 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
// 4. scatter
Scatter<T, IntT>(d_x_features_ptr,
rulebook.data<IntT>() + rulebook_len,
rulebook_ptr + rulebook_len,
rulebook_len,
in_channels,
x_grad_values_ptr);
......@@ -181,13 +184,16 @@ template <typename T, typename Context>
void Conv3dCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
PD_VISIT_INTEGRAL_TYPES(
......@@ -195,13 +201,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
Conv3dCooGradCPUKernel<T, data_t>(dev_ctx,
x,
kernel,
out,
rulebook,
counter,
out_grad,
paddings,
dilations,
strides,
groups,
subm,
key,
x_grad,
kernel_grad);
}));
......
......@@ -14,9 +14,10 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/conv.h"
namespace phi {
namespace sparse {
......@@ -35,8 +36,10 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
// update padding and dilation
// Currently, only support x.layout is NDHWC, groups = 1
// if x.layout != NDHWC then transpose(x), transpose(weight)
......@@ -66,10 +69,30 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook
DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
DenseTensor h_counter, h_offsets;
h_counter.Resize({kernel_size});
h_offsets.Resize({kernel_size + 1});
int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
// DenseTensor* rulebook = nullptr;
const IntT* rulebook_ptr = nullptr;
int n = 0;
bool need_product_rulebook = true;
if (subm && !key.empty()) {
rulebook_ptr = phi::funcs::sparse::PrepareSubm<T, IntT, CPUContext>(
dev_ctx,
x,
key,
out_dims,
out,
h_counter_ptr,
h_offsets_ptr,
&n,
&need_product_rulebook);
}
if (need_product_rulebook) {
DenseTensor tmp_rulebook;
ProductRuleBook<T, CPUContext, IntT>(dev_ctx,
x,
kernel_sizes,
......@@ -78,14 +101,18 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
subm_strides,
out_dims,
subm,
rulebook,
&counter_per_kernel);
&tmp_rulebook,
h_counter_ptr);
UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out);
dev_ctx, x, kernel_size, out_channels, out_dims, &tmp_rulebook, out);
n = tmp_rulebook.dims()[1];
rulebook_ptr = tmp_rulebook.data<IntT>();
int n = rulebook->dims()[1];
const int* counter_ptr = counter_per_kernel.data<int>();
phi::funcs::sparse::SaveToTable(
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
}
// int n = rulebook->dims()[1];
// 2. gather
DenseTensorMeta in_features_meta(
......@@ -100,34 +127,33 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
T* out_features_ptr = out_features.data<T>();
Gather<T, IntT>(x.non_zero_elements().data<T>(),
rulebook->data<IntT>() + n,
rulebook_ptr + n,
n,
in_channels,
in_features_ptr);
// 3. call gemm for every werght
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
std::vector<int> offsets(kernel_size + 1);
int offset = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter_ptr[i];
h_offsets_ptr[i] = offset;
offset += h_counter_ptr[i];
}
offsets[kernel_size] = offset;
h_offsets_ptr[kernel_size] = offset;
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter_ptr[i] <= 0) {
if (h_counter_ptr[i] <= 0) {
continue;
}
// call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = counter_ptr[i];
const int M = h_counter_ptr[i];
const int K = in_channels; // in_channels
const int N = out_channels; // out_channels
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;
T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
......@@ -143,11 +169,8 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
// 4. scatter
T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
Scatter<T, IntT>(out_features_ptr,
rulebook->data<IntT>() + n * 2,
n,
out_channels,
out_values_ptr);
Scatter<T, IntT>(
out_features_ptr, rulebook_ptr + n * 2, n, out_channels, out_values_ptr);
}
template <typename T, typename Context>
......@@ -159,8 +182,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] {
Conv3dCooCPUKernel<T, data_t>(dev_ctx,
......@@ -171,8 +196,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
strides,
groups,
subm,
key,
out,
rulebook);
rulebook,
counter);
}));
}
......
......@@ -28,6 +28,7 @@ template <typename T, typename IntT = int>
void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes,
......@@ -36,11 +37,10 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
const int channels = x.dims()[4];
int rulebook_len = rulebook.dims()[1];
const IntT* rulebook_ptr = rulebook.data<IntT>();
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0);
for (int i = 0; i < rulebook_len; i++) {
counter[rulebook_ptr[i]] += 1;
}
phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);
std::vector<int> offsets(kernel_size + 1);
const int* counter_ptr = counter.data<int>();
phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
const T* in_features_ptr = x.non_zero_elements().data<T>();
const T* out_features_ptr = out.non_zero_elements().data<T>();
......@@ -60,7 +60,7 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
phi::funcs::MaxPoolGrad<T> grad_functor;
for (int i = 0; i < kernel_size; i++) {
for (int j = 0; j < counter[i]; j++) {
for (int j = 0; j < counter_ptr[i]; j++) {
IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j];
IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j];
for (int c = 0; c < channels; c++) {
......@@ -78,6 +78,7 @@ template <typename T, typename Context>
void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes,
......@@ -85,7 +86,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] {
MaxPoolCooGradCPUKernel<T, data_t>(
dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad);
dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
}));
}
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/conv.h"
namespace phi {
namespace sparse {
......@@ -37,7 +37,8 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
const auto& x_dims = x.dims();
int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
const std::vector<int>& real_kernel_sizes =
......@@ -47,9 +48,7 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
const int in_channels = real_kernel_sizes[3];
DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
std::vector<int> counter_per_kernel(kernel_size, 0);
const T* in_features_ptr = x.non_zero_elements().data<T>();
// 1. product rule book
......@@ -62,14 +61,17 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
out_dims,
false,
rulebook,
&counter_per_kernel);
counter_per_kernel.data());
UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out);
int rulebook_len = rulebook->dims()[1];
const IntT* rulebook_ptr = rulebook->data<IntT>();
const int* counter_ptr = counter_per_kernel.data<int>();
counter->Resize({kernel_size});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, counter_per_kernel.data(), kernel_size * sizeof(int));
std::vector<int> offsets(kernel_size + 1);
phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
......@@ -105,7 +107,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] {
MaxPoolCooCPUKernel<T, data_t>(dev_ctx,
......@@ -115,7 +118,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
dilations,
strides,
out,
rulebook);
rulebook,
counter);
}));
}
......
......@@ -125,16 +125,35 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx,
}
// 5. scatter the values
const int VecSize = VecBytes / sizeof(T);
if (stride % VecSize == 0) {
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, nnz * stride / VecSize, 1);
phi::funcs::sparse::ScatterKernel<T, VecSize>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x_values_ptr,
public_indexs.data<int>(),
values_indexs_ptr,
out_nnz,
nnz,
stride,
out_values.data<T>());
} else {
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
phi::funcs::sparse::ScatterKernel<T>
<<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
x_values_ptr,
phi::funcs::sparse::ScatterKernel<T, 1>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x_values_ptr,
public_indexs.data<int>(),
values_indexs_ptr,
out_nnz,
nnz,
stride,
out_values.data<T>());
}
// 6. convert index to coordinate
Dim<DDim::kMaxRank> const_dims;
......
此差异已折叠。
......@@ -19,13 +19,11 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
namespace phi {
namespace sparse {
......@@ -42,43 +40,42 @@ template <typename T, typename IntT>
void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
const IntT* rulebook_ptr = rulebook.data<IntT>();
const int rulebook_len = rulebook.dims()[1];
int rulebook_len = 0;
const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
out, rulebook, key, &rulebook_len);
const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key);
DenseTensorMeta in_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
DenseTensorMeta d_x_features_meta(
x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
DenseTensorMeta out_grad_features_meta(
x.dtype(), {rulebook_len, out_channels}, DataLayout::NCHW);
phi::DenseTensor in_features =
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor d_x_features =
phi::Empty(dev_ctx, std::move(d_x_features_meta));
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_grad_features =
phi::Empty(dev_ctx, std::move(out_grad_features_meta));
phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>();
T* d_x_features_ptr = d_x_features.data<T>();
T* out_grad_features_ptr = out_grad_features.data<T>();
*kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
T* d_kernel_ptr = kernel_grad->data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, kernel_grad, static_cast<T>(0.0f));
phi::backends::gpu::GpuMemsetAsync(
d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
......@@ -86,8 +83,12 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
T* x_grad_values_ptr = x_grad_values.data<T>();
set_zero(dev_ctx, &x_grad_values, static_cast<T>(0.0f));
set_zero(dev_ctx, &d_x_features, static_cast<T>(0.0f));
phi::backends::gpu::GpuMemsetAsync(x_grad_values_ptr,
0,
sizeof(T) * x_grad_values.numel(),
dev_ctx.stream());
phi::backends::gpu::GpuMemsetAsync(
d_x_features_ptr, 0, sizeof(T) * d_x_features.numel(), dev_ctx.stream());
phi::Copy<GPUContext>(dev_ctx,
x.non_zero_indices(),
dev_ctx.GetPlace(),
......@@ -95,29 +96,14 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
&x_grad_indices);
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0),
h_counter(rulebook_len, 0);
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
rulebook_ptr,
rulebook_len * sizeof(IntT),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
std::vector<int> offsets(kernel_size + 1);
for (int i = 0; i < rulebook_len; i++) {
counter[h_counter[i]] += 1;
}
IntT offset = 0, max_count = 0;
int offset = 0, max_count = 0;
for (int i = 0; i < kernel_size; i++) {
offsets[i] = offset;
offset += counter[i];
offset += counter_ptr[i];
if (i < half_kernel_size) {
max_count = std::max(max_count, counter[i]);
max_count = std::max(max_count, counter_ptr[i]);
}
}
offsets[kernel_size] = offset;
......@@ -138,36 +124,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
}
}
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
GatherKernel<T, IntT><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
in_features_ptr,
rulebook_len,
in_channels);
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
DenseTensor unique_value = phi::Empty<int>(
dev_ctx, {static_cast<int>(x_grad->nnz() * kernel_size * 2)});
DenseTensor out_index =
phi::Empty<int>(dev_ctx, {static_cast<int>(x.nnz() * 2)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * out_channels, 1);
GatherKernel<T, IntT>
<<<config.block_per_grid.x,
config.thread_per_block.x,
GroupIndexsV2<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
out_grad_features_ptr,
dev_ctx.stream()>>>(rulebook_len,
x.nnz(),
kernel_size,
offsets[kernel_size / 2],
rulebook_ptr,
out_index_ptr,
unique_value_ptr);
GatherV2<T, IntT>(dev_ctx,
x.non_zero_elements().data<T>(),
out_index_ptr,
unique_value_ptr,
x.nnz(),
kernel_size,
in_channels,
2,
in_features_ptr);
Gather<T, IntT>(dev_ctx,
out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len,
rulebook_len,
out_channels);
out_channels,
out_grad_features_ptr);
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0 || (subm && i == half_kernel_size)) {
if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
continue;
}
const int M = counter[i];
const int M = counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
......@@ -204,32 +206,31 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
}
// 4. scatter
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
phi::funcs::ScatterCUDAKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
d_x_features_ptr,
rulebook_ptr + rulebook_len,
x_grad_values_ptr,
rulebook_len,
out_index.data<int>(),
unique_value.data<int>(),
x_grad->nnz(),
kernel_size,
in_channels,
false);
2,
x_grad_values_ptr);
}
template <typename T, typename Context>
void Conv3dCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const SparseCooTensor& out,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
PD_VISIT_INTEGRAL_TYPES(
......@@ -237,13 +238,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
Conv3dCooGradGPUKernel<T, data_t>(dev_ctx,
x,
kernel,
out,
rulebook,
counter,
out_grad,
paddings,
dilations,
strides,
groups,
subm,
key,
x_grad,
kernel_grad);
}));
......
......@@ -21,7 +21,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#include "glog/logging.h"
namespace phi {
namespace sparse {
......@@ -35,8 +37,10 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
// update padding and dilation
// Currently, only support x.layout is NDHWC, groups = 1
// if x.layout != NDHWC then transpose(x), transpose(weight)
......@@ -61,22 +65,40 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims);
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
DenseTensor h_counter, h_offsets;
h_counter.Resize({kernel_size});
h_offsets.Resize({kernel_size + 1});
int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook
DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensorMeta offsets_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta));
DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW);
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));
int n = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
DenseTensor offsets_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
DenseTensor out_index = phi::Empty<int>(dev_ctx, {1});
DenseTensor unique_value = phi::Empty<int>(dev_ctx, {1});
VLOG(6) << "call SubmConv3D or Conv3D " << subm << " and the key is " << key;
int rulebook_len = 0;
const IntT* rulebook_ptr = nullptr;
bool need_product_rulebook = true;
if (subm && !key.empty()) {
rulebook_ptr = phi::funcs::sparse::PrepareSubm<T, IntT, GPUContext>(
dev_ctx,
x,
key,
out_dims,
out,
h_counter.data<int>(),
h_offsets.data<int>(),
&rulebook_len,
&need_product_rulebook);
}
if (need_product_rulebook) {
DenseTensor tmp_rulebook;
rulebook_len = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
x,
kernel_sizes,
subm_paddings,
......@@ -84,62 +106,76 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
subm_strides,
out_dims,
subm,
rulebook,
&tmp_rulebook,
&counter_per_kernel,
&offsets_per_kernel,
&out_index,
&unique_value,
out,
&h_counter,
&offsets);
h_counter_ptr,
h_offsets_ptr);
rulebook_ptr = tmp_rulebook.data<IntT>();
const int* counter_ptr = counter_per_kernel.data<int>();
const int* offsets_ptr = counter_per_kernel.data<int>();
const IntT* rulebook_ptr = rulebook->data<IntT>();
phi::funcs::sparse::SaveToTable(
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
}
// 2. gather
DenseTensorMeta in_features_meta(
x.dtype(), {n, in_channels}, DataLayout::NCHW);
DenseTensorMeta out_features_meta(
x.dtype(), {n, out_channels}, DataLayout::NCHW);
phi::DenseTensor in_features =
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_features =
phi::Empty(dev_ctx, std::move(out_features_meta));
phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1);
GatherKernel<T, IntT><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
rulebook_ptr + n,
in_features_ptr,
n,
in_channels);
Gather<T, IntT>(dev_ctx,
x.non_zero_elements().data<T>(),
rulebook_ptr,
rulebook_len,
in_channels,
in_features_ptr);
// 3. call gemm for every werght
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>();
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
if (subm) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
unique_value.ResizeAndAllocate(
{static_cast<int>(out->nnz() * kernel_size)});
out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
GroupIndexs<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
kernel_size,
rulebook_ptr + rulebook_len,
out_index_ptr,
unique_value_ptr);
}
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (h_counter[i] <= 0) {
if (h_counter_ptr[i] <= 0) {
continue;
}
// call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = h_counter[i];
const int M = h_counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;
T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans,
CblasNoTrans,
......@@ -154,40 +190,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
}
// 4. scatter
if (subm) {
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1);
phi::funcs::ScatterCUDAKernel<T, IntT>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(out_features_ptr,
rulebook_ptr + 2 * n,
out_values_ptr,
n,
out_channels,
false);
} else {
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, out->nnz() * out_channels, 1);
phi::funcs::sparse::ScatterKernel<T>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_features_ptr,
unique_value.data<int>(),
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
out_features_ptr,
out_index.data<int>(),
unique_value.data<int>(),
out->nnz(),
n,
kernel_size,
out_channels,
1,
out_values_ptr);
}
}
/**
* x: (N, D, H, W, C)
* kernel: (D, H, W, C, OC)
* out: (N, D, H, W, OC)
* x: the input SparseCooTensor, shape is (N, D, H, W, C)
* kernel: the weight data, shape is (D, H, W, C, OC)
* out: the output SparseCooTensor, shape is (N, D, H, W, OC)
* rulebook: return rulebook if key is not vailed else return nullptr
* counter: return counter if key is not vailed else return nullptr
**/
template <typename T, typename Context>
void Conv3dCooKernel(const Context& dev_ctx,
......@@ -198,8 +217,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const int groups,
const bool subm,
const std::string& key,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooGPUKernel", ([&] {
Conv3dCooGPUKernel<T, data_t>(dev_ctx,
......@@ -210,8 +231,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
strides,
groups,
subm,
key,
out,
rulebook);
rulebook,
counter);
}));
}
......
......@@ -238,6 +238,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
x_indexs_ptr, x_indexs.numel(), table.data<int>());
config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
const int VecBytes = 16;
const int VecSize = VecBytes / sizeof(T);
if (stride % VecSize == 0) {
......
......@@ -55,6 +55,7 @@ template <typename T, typename IntT = int>
void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes,
......@@ -63,23 +64,9 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
const int in_channels = x.dims()[4];
int rulebook_len = rulebook.dims()[1];
const IntT* rulebook_ptr = rulebook.data<IntT>();
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0),
h_counter(rulebook_len, 0);
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
rulebook_ptr,
rulebook_len * sizeof(IntT),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
for (int i = 0; i < rulebook_len; i++) {
counter[h_counter[i]] += 1;
}
phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);
std::vector<int> offsets(kernel_size + 1);
const int* counter_ptr = counter.data<int>();
phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
const T* in_features_ptr = x.non_zero_elements().data<T>();
const T* out_features_ptr = out.non_zero_elements().data<T>();
......@@ -99,12 +86,12 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
&x_grad_indices);
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) {
if (counter_ptr[i] <= 0) {
continue;
}
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, counter[i] * in_channels, 1);
dev_ctx, counter_ptr[i] * in_channels, 1);
MaxPoolGradCudaKernel<T, IntT>
<<<config.block_per_grid.x,
config.thread_per_block.x,
......@@ -112,8 +99,8 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
dev_ctx.stream()>>>(in_features_ptr,
out_features_ptr,
out_grad_ptr,
rulebook_ptr + offsets[i] + rulebook_len,
counter[i],
rulebook_ptr + offsets[i],
counter_ptr[i],
rulebook_len,
in_channels,
x_grad_ptr);
......@@ -124,6 +111,7 @@ template <typename T, typename Context>
void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes,
......@@ -131,7 +119,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGradGPUKernel", ([&] {
MaxPoolCooGradGPUKernel<T, data_t>(
dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad);
dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
}));
}
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
namespace phi {
namespace sparse {
......@@ -55,7 +55,8 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
const auto& x_dims = x.dims();
int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
const std::vector<int>& real_kernel_sizes =
......@@ -65,7 +66,7 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
const int in_channels = real_kernel_sizes[3];
std::vector<int> offsets(kernel_size + 1), counter(kernel_size);
std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
DenseTensorMeta counter_meta(
DataType::INT32, {kernel_size}, DataLayout::NCHW);
DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
......@@ -89,13 +90,16 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
&out_index,
&unique_value,
out,
&counter,
&offsets);
h_counter.data(),
offsets.data());
const IntT* rulebook_ptr = rulebook->data<IntT>();
T* out_features_ptr = out->mutable_non_zero_elements()->data<T>();
const T* in_features_ptr = x.non_zero_elements().data<T>();
counter->Resize({kernel_size});
int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
memcpy(counter_ptr, h_counter.data(), h_counter.size() * sizeof(int));
// 2. max pool
#ifdef PADDLE_WITH_HIP
thrust::fill(thrust::hip::par.on(dev_ctx.stream()),
......@@ -107,19 +111,18 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
static_cast<T>(0));
// TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster
for (int i = 0; i < kernel_size; i++) {
if (counter[i] <= 0) {
if (h_counter[i] <= 0) {
continue;
}
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, counter[i] * in_channels, 1);
MaxPoolCudaKernel<T, IntT>
<<<config.block_per_grid.x,
dev_ctx, h_counter[i] * in_channels, 1);
MaxPoolCudaKernel<T, IntT><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(in_features_ptr,
rulebook_ptr + offsets[i] + rulebook_len,
counter[i],
rulebook_ptr + offsets[i],
h_counter[i],
rulebook_len,
in_channels,
out_features_ptr);
......@@ -134,7 +137,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
SparseCooTensor* out,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGPUKernel", ([&] {
MaxPoolCooGPUKernel<T, data_t>(dev_ctx,
......@@ -144,7 +148,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
dilations,
strides,
out,
rulebook);
rulebook,
counter);
}));
}
......
......@@ -25,6 +25,7 @@ template <typename T, typename Context>
void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes,
......@@ -34,12 +35,13 @@ template <typename T, typename Context>
SparseCooTensor MaxPoolCooGrad(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& counter,
const SparseCooTensor& out,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes) {
SparseCooTensor x_grad;
MaxPoolCooGradKernel<T, Context>(
dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad);
dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, &x_grad);
return x_grad;
}
......
......@@ -29,7 +29,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
SparseCooTensor* out,
DenseTensor* rulebook);
DenseTensor* rulebook,
DenseTensor* counter);
template <typename T, typename Context>
SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
......@@ -38,10 +39,18 @@ SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
DenseTensor* rulebook) {
DenseTensor* rulebook,
DenseTensor* counter) {
SparseCooTensor coo;
MaxPoolCooKernel<T, Context>(
dev_ctx, x, kernel_sizes, paddings, dilations, strides, &coo, rulebook);
MaxPoolCooKernel<T, Context>(dev_ctx,
x,
kernel_sizes,
paddings,
dilations,
strides,
&coo,
rulebook,
counter);
return coo;
}
......
......@@ -76,8 +76,8 @@ void TestConv3dBase(const std::vector<int>& indices,
kernel.size() * sizeof(T));
if (!std::is_same<T, phi::dtype::float16>::value) {
auto tensor_out = paddle::experimental::sparse::conv3d(
x, weight, paddings, dilations, strides, 1, false);
auto tensor_out = paddle::experimental::sparse::conv3d_coo(
x, weight, paddings, dilations, strides, 1, false, "Conv3d");
auto out =
std::dynamic_pointer_cast<phi::SparseCooTensor>(tensor_out.impl());
......
......@@ -112,8 +112,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
};
if (!std::is_same<T, phi::dtype::float16>::value) {
DenseTensor rulebook = phi::Empty(
dev_ctx_cpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW));
DenseTensor rulebook, counter;
SparseCooTensor out = sparse::Conv3dCoo<T>(dev_ctx_cpu,
x_tensor,
kernel_tensor,
......@@ -122,7 +121,9 @@ void TestConv3dBase(const std::vector<IntT>& indices,
strides,
1,
subm,
&rulebook);
"Conv3d",
&rulebook,
&counter);
ASSERT_EQ(correct_out_dims.size(), out.dims().size());
for (int i = 0; i < correct_out_dims.size(); i++) {
......@@ -142,13 +143,16 @@ void TestConv3dBase(const std::vector<IntT>& indices,
sparse::Conv3dCooGrad<T>(dev_ctx_cpu,
x_tensor,
kernel_tensor,
out,
rulebook,
counter,
out,
paddings,
dilations,
strides,
1,
subm);
subm,
"Conv3d");
f_verify(std::get<0>(grads).non_zero_elements().data<T>(), features_grad);
f_verify(std::get<1>(grads).data<T>(), kernel_grad);
}
......@@ -196,8 +200,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
phi::Copy(
dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor);
DenseTensor d_rulebook = phi::Empty(
dev_ctx_gpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW));
DenseTensor d_rulebook, d_counter;
SparseCooTensor d_out = sparse::Conv3dCoo<T>(dev_ctx_gpu,
d_x_tensor,
d_kernel_tensor,
......@@ -206,8 +209,9 @@ void TestConv3dBase(const std::vector<IntT>& indices,
strides,
1,
subm,
&d_rulebook);
"Conv3d",
&d_rulebook,
&d_counter);
SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
......@@ -245,13 +249,16 @@ void TestConv3dBase(const std::vector<IntT>& indices,
sparse::Conv3dCooGrad<T>(dev_ctx_gpu,
d_x_tensor,
d_kernel_tensor,
d_out,
d_rulebook,
d_counter,
d_out,
paddings,
dilations,
strides,
1,
subm);
subm,
"Conv3d");
DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements();
DenseTensor d_kernel_grad = std::get<1>(grads);
DenseTensor h_features_grad =
......
......@@ -90,14 +90,15 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
};
if (!std::is_same<T, phi::dtype::float16>::value) {
DenseTensor rulebook;
DenseTensor rulebook, counter;
SparseCooTensor out = sparse::MaxPoolCoo<T>(dev_ctx_cpu,
x_tensor,
kernel_sizes,
paddings,
dilations,
strides,
&rulebook);
&rulebook,
&counter);
ASSERT_EQ(correct_out_dims.size(), out.dims().size());
for (int i = 0; i < correct_out_dims.size(); i++) {
......@@ -114,7 +115,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
if (backward) {
SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes);
dev_ctx_cpu, x_tensor, rulebook, counter, out, out, kernel_sizes);
f_verify(x_grad.non_zero_elements().data<T>(), features_grad);
}
}
......@@ -150,14 +151,16 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims);
DenseTensor d_rulebook;
DenseTensor d_rulebook, d_counter;
SparseCooTensor d_out = sparse::MaxPoolCoo<T>(dev_ctx_gpu,
d_x_tensor,
kernel_sizes,
paddings,
dilations,
strides,
&d_rulebook);
&d_rulebook,
&d_counter);
SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
......@@ -191,8 +194,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
f_verify(h_features_tensor.data<T>(), correct_out_features);
if (backward) {
SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes);
SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(dev_ctx_gpu,
d_x_tensor,
d_rulebook,
d_counter,
d_out,
d_out,
kernel_sizes);
DenseTensor h_features_grad =
phi::EmptyLike<T>(dev_ctx_cpu, x_grad.non_zero_elements());
phi::Copy(dev_ctx_gpu,
......
......@@ -67,7 +67,7 @@ class TestSparseConv(unittest.TestCase):
indices, values, dense_shape, stop_gradient=True)
weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')
y = paddle.incubate.sparse.nn.functional.subm_conv3d(
sparse_x, weight)
sparse_x, weight, key='subm_conv')
assert np.array_equal(sparse_x.indices().numpy(),
y.indices().numpy())
......@@ -91,7 +91,7 @@ class TestSparseConv(unittest.TestCase):
with self.assertRaises(ValueError):
#Currently, only support data_format='NDHWC'
conv3d = paddle.incubate.sparse.nn.SubmConv3D(
1, 1, (1, 3, 3), data_format='NCDHW')
1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')
def test_SubmConv3D(self):
with _test_eager_guard():
......@@ -105,7 +105,7 @@ class TestSparseConv(unittest.TestCase):
indices, values, dense_shape, False)
subm_conv3d = paddle.incubate.sparse.nn.SubmConv3D(
1, 1, (1, 3, 3), data_format='NDHWC')
1, 1, (1, 3, 3), data_format='NDHWC', key='subm_conv')
# test extra_repr
print(subm_conv3d.extra_repr())
......@@ -117,7 +117,7 @@ class TestSparseConv(unittest.TestCase):
with self.assertRaises(ValueError):
#Currently, only support data_format='NDHWC'
conv3d = paddle.incubate.sparse.nn.SubmConv3D(
1, 1, (1, 3, 3), data_format='NCDHW')
1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')
def test_Conv3D_bias(self):
with _test_eager_guard():
......
......@@ -29,6 +29,7 @@ def _conv3d(x,
dilation=1,
groups=1,
subm=False,
key=None,
data_format="NDHWC",
name=None):
assert in_dynamic_mode(), "Currently, only support dynamic mode"
......@@ -62,8 +63,9 @@ def _conv3d(x,
dilation = convert_to_list(dilation, dims, 'dilation')
op_type = "conv3d"
pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
stride, groups, subm)
pre_bias = _C_ops.final_state_sparse_conv3d_coo(
x, weight, padding, dilation, stride, groups, subm,
key if key is not None else "")
if bias is not None:
values = pre_bias.values()
add_bias = elementwise_add(values, bias, axis=1)
......@@ -186,7 +188,7 @@ def conv3d(x,
# (1, 1, 1, 2, 1)
"""
return _conv3d(x, weight, bias, stride, padding, dilation, groups, False,
data_format, name)
None, data_format, name)
def subm_conv3d(x,
......@@ -197,6 +199,7 @@ def subm_conv3d(x,
dilation=1,
groups=1,
data_format="NDHWC",
key=None,
name=None):
r"""
......@@ -274,6 +277,10 @@ def subm_conv3d(x,
will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`.
The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of:
`[batch_size, input_depth, input_height, input_width, input_channels]`.
key(str, optional): the key is used to save or use the same rulebook,
the definition and role of rulebook refers to
https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The
default value is None.
name(str|None): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
......@@ -301,4 +308,4 @@ def subm_conv3d(x,
#(1, 1, 3, 4, 1)
"""
return _conv3d(x, weight, bias, stride, padding, dilation, groups, True,
data_format, name)
key, data_format, name)
......@@ -33,6 +33,7 @@ class _Conv3D(Layer):
dilation=1,
groups=1,
subm=False,
key=None,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
......@@ -46,6 +47,7 @@ class _Conv3D(Layer):
self._out_channels = out_channels
self._data_format = data_format
self._subm = subm
self._key = key
assert padding_mode == 'zeros', "Currently, only support padding_mode='zeros'"
assert groups == 1, "Currently, only support groups=1"
......@@ -95,6 +97,7 @@ class _Conv3D(Layer):
dilation=self._dilation,
groups=self._groups,
subm=self._subm,
key=self._key,
data_format=self._data_format)
return out
......@@ -240,6 +243,7 @@ class Conv3D(_Conv3D):
dilation=dilation,
groups=groups,
subm=False,
key=None,
padding_mode=padding_mode,
weight_attr=weight_attr,
bias_attr=bias_attr,
......@@ -293,6 +297,10 @@ class SubmConv3D(_Conv3D):
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. The default value is 1.
padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``.
key(str, optional): the key is used to save or use the same rulebook,
the definition and role of rulebook refers to
https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The
default value is None.
weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights
of conv3d. If it is set to None or one attribute of ParamAttr, conv3d
will create ParamAttr as param_attr. If it is set to None, the parameter
......@@ -361,6 +369,7 @@ class SubmConv3D(_Conv3D):
dilation=1,
groups=1,
padding_mode='zeros',
key=None,
weight_attr=None,
bias_attr=None,
data_format="NDHWC"):
......@@ -372,6 +381,7 @@ class SubmConv3D(_Conv3D):
dilation=dilation,
groups=groups,
subm=True,
key=key,
padding_mode=padding_mode,
weight_attr=weight_attr,
bias_attr=bias_attr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册