Unverified commit 2f5fb031
Authored Mar 16, 2022 by zhangkaihuo; committed by GitHub on Mar 16, 2022
Restructure sparse conv (#40570)
restructure conv
Parent: 3898080e
Showing 9 changed files with 555 additions and 552 deletions
paddle/phi/kernels/funcs/sparse/convolution.h               +10  -10
paddle/phi/kernels/sparse/convolution_grad_kernel.h          +2   -2
paddle/phi/kernels/sparse/cpu/convolution.h                  +7   -7
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc     +2   -2
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc          +7   -2
paddle/phi/kernels/sparse/gpu/convolution.cu.h             +493   -0
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu     +6   -7
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu          +6 -502
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc      +22  -20
paddle/phi/kernels/funcs/sparse/convolution.h
@@ -93,7 +93,7 @@ inline HOSTDEVICE void IndexToPoint(
 }

 inline void GetOutShape(const DDim& x_dims,
-                        const DDim& kernel_dims,
+                        const std::vector<int>& kernel_sizes,
                         const std::vector<int>& paddings,
                         const std::vector<int>& dilations,
                         const std::vector<int>& strides,
@@ -102,17 +102,17 @@ inline void GetOutShape(const DDim& x_dims,
       x_dims.size(),
       5,
       phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)"));
-  PADDLE_ENFORCE_EQ(kernel_dims.size(),
+  PADDLE_ENFORCE_EQ(kernel_sizes.size(),
       5,
       phi::errors::InvalidArgument("the shape of kernel should be (D, H, W, C, OC)"));

   // infer out shape
   (*out_dims)[0] = x_dims[0];
-  (*out_dims)[4] = kernel_dims[4];
+  (*out_dims)[4] = kernel_sizes[4];
   for (int i = 1; i < 4; i++) {
     (*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] -
-                      dilations[i - 1] * (kernel_dims[i - 1] - 1) - 1) /
+                      dilations[i - 1] * (kernel_sizes[i - 1] - 1) - 1) /
                          strides[i - 1] +
                      1;
   }
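The loop body above is the standard dilated-convolution output-size formula applied to each of D, H and W: out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1. A standalone sanity check of the formula (the concrete numbers are illustrative, not taken from the commit):

#include <cstdio>

// Output extent of one spatial axis, as computed in GetOutShape.
int OutSize(int in, int pad, int dilation, int kernel, int stride) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // e.g. a 16-voxel axis with a 3-tap kernel, padding 1, stride 2, dilation 1:
  // (16 + 2 - 2 - 1) / 2 + 1 = 8
  std::printf("%d\n", OutSize(16, 1, 1, 3, 2));  // prints 8
  return 0;
}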
@@ -131,7 +131,7 @@ template <typename T, typename Context>
 inline void SubmPreProcess(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            const DenseTensor& kernel,
-                           const SparseCooTensor& out_grad,
+                           const DenseTensor& out_grad,
                            const int in_channels,
                            const int out_channels,
                            const int half_kernel_size,
@@ -142,11 +142,11 @@ inline void SubmPreProcess(const Context& dev_ctx,
   blas.GEMM(CblasTrans,
             CblasNoTrans,
             x.non_zero_elements().dims()[1],
-            out_grad.non_zero_elements().dims()[1],
+            out_grad.dims()[1],
             x.non_zero_elements().dims()[0],
             static_cast<T>(1),
             x.non_zero_elements().data<T>(),
-            out_grad.non_zero_elements().data<T>(),
+            out_grad.data<T>(),
             static_cast<T>(0),
             d_kernel_ptr + half_kernel_size * in_channels * out_channels);
@@ -155,11 +155,11 @@ inline void SubmPreProcess(const Context& dev_ctx,
   T* x_grad_ptr = x_grad->data<T>();
   blas.GEMM(CblasNoTrans,
             CblasTrans,
-            out_grad.non_zero_elements().dims()[0],
+            out_grad.dims()[0],
             in_channels,
-            out_grad.non_zero_elements().dims()[1],
+            out_grad.dims()[1],
             static_cast<T>(1),
-            out_grad.non_zero_elements().data<T>(),
+            out_grad.data<T>(),
             kernel.data<T>() + half_kernel_size * in_channels * out_channels,
             static_cast<T>(0),
             x_grad_ptr);
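Both GEMM calls in SubmPreProcess operate on the center tap of the kernel only (the `half_kernel_size * in_channels * out_channels` offset). Reading the shapes off the arguments, with $X$ the $(\mathrm{nnz}, C_{in})$ matrix `x.non_zero_elements()`, $\nabla Y$ the $(\mathrm{nnz}, C_{out})$ gradient `out_grad` (now passed as a plain DenseTensor), and $K_c$ the $(C_{in}, C_{out})$ center slice of the kernel, the two calls compute (our reading of the BLAS arguments, not wording from the commit):

$\nabla K_c = X^{\top}\,\nabla Y \qquad\text{and}\qquad \nabla X = \nabla Y\,K_c^{\top}$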
paddle/phi/kernels/sparse/convolution_grad_kernel.h
@@ -27,7 +27,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& rulebook,
                       const DenseTensor& kernel,
-                      const SparseCooTensor& out_grad,
+                      const DenseTensor& out_grad,
                       const std::vector<int>& paddings,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
@@ -41,7 +41,7 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
                                     const SparseCooTensor& x,
                                     const DenseTensor& rulebook,
                                     const DenseTensor& kernel,
-                                    const SparseCooTensor& out_grad,
+                                    const DenseTensor& out_grad,
                                     const std::vector<int>& paddings,
                                     const std::vector<int>& dilations,
                                     const std::vector<int>& strides,
paddle/phi/kernels/sparse/cpu/convolution.h
@@ -34,7 +34,7 @@ using Dims4D = phi::funcs::sparse::Dims4D;
 template <typename T, typename Context>
 void ProductRuleBook(const Context& dev_ctx,
                      const SparseCooTensor& x,
-                     const DenseTensor& kernel,
+                     const std::vector<int>& kernel_sizes,
                      const std::vector<int>& paddings,
                      const std::vector<int>& dilations,
                      const std::vector<int>& strides,
@@ -42,19 +42,19 @@ void ProductRuleBook(const Context& dev_ctx,
                      const bool subm,
                      DenseTensor* rulebook,
                      DenseTensor* counter_per_kernel) {
-  const auto& kernel_dims = kernel.dims();
   const int64_t non_zero_num = x.nnz();
   const auto& non_zero_indices = x.non_zero_indices();
   const int* indices_ptr = non_zero_indices.data<int>();
   int* counter_ptr = counter_per_kernel->data<int>();
-  int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
+  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   memset(counter_ptr, 0, kernel_size * sizeof(int));
   int rulebook_len = 0;
   // calc the rulebook_len
   const auto& x_dims = x.dims();
   const Dims4D c_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
-  const Dims4D c_kernel_dims(1, kernel_dims[2], kernel_dims[1], kernel_dims[0]);
+  const Dims4D c_kernel_dims(
+      1, kernel_sizes[2], kernel_sizes[1], kernel_sizes[0]);
   const Dims4D c_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
   const Dims4D c_paddings(1, paddings[2], paddings[1], paddings[0]);
   const Dims4D c_strides(1, strides[2], strides[1], strides[0]);
@@ -75,9 +75,9 @@ void ProductRuleBook(const Context& dev_ctx,
   auto f_calc_rulebook = [&](int* rulebook_ptr) {
     int kernel_index = 0, rulebook_index = 0;
-    for (int kz = 0; kz < kernel_dims[0]; kz++) {
-      for (int ky = 0; ky < kernel_dims[1]; ky++) {
-        for (int kx = 0; kx < kernel_dims[2]; kx++) {
+    for (int kz = 0; kz < kernel_sizes[0]; kz++) {
+      for (int ky = 0; ky < kernel_sizes[1]; ky++) {
+        for (int kx = 0; kx < kernel_sizes[2]; kx++) {
           ++kernel_index;
           for (int64_t i = 0; i < non_zero_num; i++) {
             int batch = indices_ptr[i];
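The rulebook these loops fill is a 3 x (kernel_size * nnz) table that is later compacted; the grad kernels read its third row through `rulebook_ptr + rulebook_len * 2`. A minimal sketch of one logical column, as we read it from the indexing (the struct is illustrative only, not a type in the commit):

// Illustrative only: one logical column of the 3-row rulebook.
struct RulebookColumn {
  int kernel_index;  // row 0: which kernel tap (0 .. D*H*W - 1) fired
  int in_index;      // row 1: position of the input nonzero in the feature list
  int out_index;     // row 2: flattened output coordinate, deduplicated later
};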
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
@@ -33,7 +33,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& rulebook,
                       const DenseTensor& kernel,
-                      const SparseCooTensor& out_grad,
+                      const DenseTensor& out_grad,
                       const std::vector<int>& paddings,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
@@ -113,7 +113,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
             rulebook_len,
             in_channels,
             in_features_ptr);
-  Gather<T>(out_grad.non_zero_elements().data<T>(),
+  Gather<T>(out_grad.data<T>(),
             rulebook_ptr + rulebook_len * 2,
             rulebook_len,
             out_channels,
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
@@ -44,8 +44,13 @@ void Conv3dKernel(const Context& dev_ctx,
   const auto& kernel_dims = kernel.dims();
   int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
   DDim out_dims = {1, 1, 1, 1, 1};
+  std::vector<int> kernel_sizes(kernel_dims.size());
+  for (int i = 0; i < kernel_dims.size(); i++) {
+    kernel_sizes[i] = kernel_dims[i];
+  }
   phi::funcs::sparse::GetOutShape(
-      x_dims, kernel_dims, paddings, dilations, strides, &out_dims);
+      x_dims, kernel_sizes, paddings, dilations, strides, &out_dims);
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
@@ -63,7 +68,7 @@ void Conv3dKernel(const Context& dev_ctx,
   ProductRuleBook<T, Context>(dev_ctx,
                               x,
-                              kernel,
+                              kernel_sizes,
                               subm_paddings,
                               dilations,
                               subm_strides,
paddle/phi/kernels/sparse/gpu/convolution.cu.h
@@ -23,11 +23,15 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/index_impl.cu.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/primitive/compute_primitives.h"
 #include "paddle/phi/kernels/sparse/convolution_kernel.h"

 namespace phi {
 namespace sparse {

+using Dims4D = phi::funcs::sparse::Dims4D;

 // TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
 // this kernel with phi::GatherCUDAKernel;
 // Vectorization can be used to improve read and write bandwidth
@@ -139,5 +143,494 @@ inline int* SortedAndUniqueIndex(const Context& dev_ctx,
   return new_end.first;
 }

[every line below, down to the closing namespace braces, is added by this commit]

template <typename T>
__global__ void SetFlagAndUpdateCounterKernel(const int* indexs,
                                              const int n,
                                              const int rulebook_len,
                                              const int kernel_size,
                                              T* rulebook_ptr,
                                              int* counter_ptr) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  extern __shared__ int cache_count[];  // kernel_size
  for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
    cache_count[i] = 0;
  }
  __syncthreads();

  for (int i = tid; i < n; i += gridDim.x * blockDim.x) {
    int index = indexs[i];
    int kernel_index = rulebook_ptr[index];
    rulebook_ptr[index + rulebook_len] = -1;
    rulebook_ptr[index + 2 * rulebook_len] = -1;
    rulebook_ptr[index] = -1;
    atomicAdd(&cache_count[kernel_index], 1);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
    atomicSub(&counter_ptr[i], cache_count[i]);
  }
}

/**
 * @brief: update the out index and indices
 * unique_keys: save the index of the output feature list
 * unique_values: indicates the index of key before deduplication
 * out_indexs: indicates the position of the output index in the rulebook
 * rulebook_len: indicates the length of rulebook
 * out_dims: indicates the output dims
 * out_indices: the indices of output, out_indices = IndexToPoint(unique_keys)
 * rulebook_out_indexs: the output index in rulebook
 **/
template <typename T>
__global__ void UpdateIndexKernel(const int* unique_keys,
                                  const int* unique_values,
                                  const int* out_indexs,
                                  const int non_zero_num,
                                  const int rulebook_len,
                                  const Dims4D out_dims,
                                  T* out_indices,
                                  T* rulebook_out_indexs) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    const int index = unique_keys[i];
    int batch, x, y, z;
    phi::funcs::sparse::IndexToPoint<Dims4D>(
        index, out_dims, &batch, &x, &y, &z);
    // get out indices
    out_indices[i] = batch;
    out_indices[i + non_zero_num] = z;
    out_indices[i + non_zero_num * 2] = y;
    out_indices[i + non_zero_num * 3] = x;

    // update rulebook
    int start = unique_values[i];
    int end = i == non_zero_num - 1 ? rulebook_len : unique_values[i + 1];
    // max(end-start) = kernel_size
    for (int j = start; j < end; j++) {
      rulebook_out_indexs[out_indexs[j]] = i;
    }
  }
}

// brief: calculate the distance between start and end
template <typename T>
__global__ void DistanceKernel(const T* start, const T* end, int* distance) {
  if (threadIdx.x == 0) {
    *distance = end - start;
  }
}

/**
 * @brief product rulebook
 * for input_i in x_indices:
 *   if input_i participates in the convolution calculation:
 *       infer the output_i by input_i and kernel_i
 *       save output_i
 *
 * x_indices: the indices of input features
 * x_dims: the input dims
 * kernel_dims: the kernel dims
 * out_dims: the output dims
 * non_zero_num: the number of input features
 * rulebook: the rulebook to save the kernel index, input index and output index
 * counter: save the number of times each location in the kernel participates
 * in the calculation
 **/
template <typename T>
__global__ void ProductRuleBookKernel(const T* x_indices,
                                      const Dims4D x_dims,
                                      const Dims4D kernel_dims,
                                      const Dims4D out_dims,
                                      const int64_t non_zero_num,
                                      const Dims4D paddings,
                                      const Dims4D dilations,
                                      const Dims4D strides,
                                      const bool subm,
                                      T* rulebook,
                                      int* counter,
                                      int* in_indexs) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  extern __shared__ int counter_buf[];  // kernel_size
  const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1];
  const int offset = kernel_size * non_zero_num;
  for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
    counter_buf[i] = 0;
  }
  __syncthreads();

  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    int kernel_index = 0;
    int batch = x_indices[i];
    int in_z = x_indices[i + non_zero_num];
    int in_y = x_indices[i + 2 * non_zero_num];
    int in_x = x_indices[i + 3 * non_zero_num];
    if (subm) {
      in_indexs[i] = PointToIndex(batch, in_x, in_y, in_z, x_dims);
    }
    for (int kz = 0; kz < kernel_dims[1]; kz++) {
      for (int ky = 0; ky < kernel_dims[2]; ky++) {
        for (int kx = 0; kx < kernel_dims[3]; kx++) {
          int in_i = -1, out_index = -1, kernel_i = -1;
          if (phi::funcs::sparse::Check(x_dims,
                                        kernel_dims,
                                        paddings,
                                        dilations,
                                        strides,
                                        in_x,
                                        in_y,
                                        in_z,
                                        kx,
                                        ky,
                                        kz)) {
            int out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1];
            int out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2];
            int out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
            in_i = i;
            out_index = phi::funcs::sparse::PointToIndex<Dims4D>(
                batch, out_x, out_y, out_z, out_dims);
            atomicAdd(&counter_buf[kernel_index], 1);
            kernel_i = kernel_index;
          }
          rulebook[kernel_index * non_zero_num + i] = kernel_i;
          rulebook[kernel_index * non_zero_num + offset + i] = in_i;
          rulebook[kernel_index * non_zero_num + offset * 2 + i] = out_index;
          ++kernel_index;
        }
      }
    }
  }
  __syncthreads();
  for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
    atomicAdd(&counter[i], counter_buf[i]);
  }
}

// the basic algorithm can refer to convolution_kernel.cc or
// the second paper
// example:
// 1. the rulebook:
//  the kernel_index: 0, 0, 0, 1, 1, 1, 2, 2, ....
//  the out_index(key): 20, 30, 33, 30, 33, 20, 25
// 2. mark the index of out_index(value): 0, 1, 2, 3, 4, 5, 6, ....
// 3. sort the (key, value)
// 4. unique the (key, value):
//  unique_key: 20, 25, 30, 33
//  unique_values: 0, 2, 3, 5
//  the index of unique_values is: 0, 1, 2, 3
// 5. update the out_index by unique_key, unique_value and the index of
// unique_value:
//  the new out_index: 0, 2, 3, 2, 3, 0, 1
template <typename T, typename Context>
int ProductRuleBook(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    const std::vector<int>& kernel_sizes,
                    const std::vector<int>& paddings,
                    const std::vector<int>& dilations,
                    const std::vector<int>& strides,
                    const DDim& out_dims,
                    const bool subm,
                    DenseTensor* rulebook,
                    DenseTensor* counter_per_kernel,
                    DenseTensor* offsets_per_kernel,
                    DenseTensor* out_index,
                    DenseTensor* unique_key,
                    DenseTensor* unique_value,
                    SparseCooTensor* out,
                    std::vector<int>* h_counter,
                    std::vector<int>* h_offsets) {
  const int64_t non_zero_num = x.nnz();
  const auto& non_zero_indices = x.non_zero_indices();
  const int* indices_ptr = non_zero_indices.data<int>();
  DenseTensor in_indexs = phi::Empty<Context>(
      dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW));
  int* counter_ptr = counter_per_kernel->data<int>();
  int* offsets_ptr = offsets_per_kernel->data<int>();
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
  const int rulebook_rows = 3;
  const int rulebook_cols = kernel_size * non_zero_num;
  rulebook->ResizeAndAllocate({rulebook_rows, rulebook_cols});
  int* rulebook_ptr = rulebook->data<int>();
  const auto x_dims = x.dims();
  Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
  Dims4D d_kernel_dims(1, kernel_sizes[2], kernel_sizes[1], kernel_sizes[0]);
  Dims4D d_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
  Dims4D d_paddings(1, paddings[2], paddings[1], paddings[0]);
  Dims4D d_strides(1, strides[2], strides[1], strides[0]);
  Dims4D d_dilations(1, dilations[2], dilations[1], dilations[0]);

  // 1. product rule book
  phi::funcs::SetConstant<Context, int> set_zero;
  set_zero(dev_ctx, counter_per_kernel, 0);
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);

  ProductRuleBookKernel<int><<<config.block_per_grid.x,
                               config.thread_per_block.x,
                               kernel_size * sizeof(int),
                               dev_ctx.stream()>>>(indices_ptr,
                                                   d_x_dims,
                                                   d_kernel_dims,
                                                   d_out_dims,
                                                   non_zero_num,
                                                   d_paddings,
                                                   d_dilations,
                                                   d_strides,
                                                   subm,
                                                   rulebook_ptr,
                                                   counter_ptr,
                                                   in_indexs.data<int>());

  // 2. remove -1
#ifdef PADDLE_WITH_HIP
  int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
  int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                             rulebook_ptr,
                             rulebook_ptr + rulebook_rows * rulebook_cols,
                             -1);

  DistanceKernel<int><<<1, 1, 0, dev_ctx.stream()>>>(
      rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1);
  int rulebook_len = 0;
  phi::backends::gpu::GpuMemcpyAsync(
      &rulebook_len,
      rulebook_ptr + 3 * kernel_size * non_zero_num - 1,
      sizeof(int),
#ifdef PADDLE_WITH_HIP
      hipMemcpyDeviceToHost,
#else
      cudaMemcpyDeviceToHost,
#endif
      dev_ctx.stream());
  rulebook_len /= 3;
  dev_ctx.Wait();

  if (subm) {
    // At present, hashtable is not used to map the input and output indexes.
    // At present, the intermediate output index is generated by normal
    // convolution, and then the intermediate output index is subtracted from
    // the input index to obtain the rulebook.
    // get difference
    int32_t* A_key_ptr = rulebook_ptr + 2 * rulebook_len;
    int32_t* B_key_ptr = in_indexs.data<int>();
    DenseTensor A_val = phi::Empty<Context>(
        dev_ctx,
        DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));
    DenseTensor B_val = phi::Empty<Context>(
        dev_ctx,
        DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW));
    phi::IndexKernel<int, kps::IdentityFunctor<int>>(
        dev_ctx, &A_val, kps::IdentityFunctor<int>());
    phi::IndexKernel<int, kps::IdentityFunctor<int>>(
        dev_ctx, &B_val, kps::IdentityFunctor<int>());
    DenseTensor key_result = phi::Empty<Context>(
        dev_ctx,
        DenseTensorMeta(DataType::INT32, {rulebook_len + 1}, DataLayout::NCHW));
    DenseTensor val_result = phi::Empty<Context>(
        dev_ctx,
        DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW));

#ifdef PADDLE_WITH_HIP
    thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                           counter_ptr,
                           counter_ptr + kernel_size,
                           offsets_ptr);
    std::vector<int> offsets(kernel_size, 0);
    // TODO(zhangkaihuo): use unified memcpy interface
    phi::backends::gpu::GpuMemcpyAsync(offsets.data(),
                                       offsets_ptr,
                                       kernel_size * sizeof(int),
#ifdef PADDLE_WITH_HIP
                                       hipMemcpyDeviceToHost,
#else
                                       cudaMemcpyDeviceToHost,
#endif
                                       dev_ctx.stream());
    dev_ctx.Wait();

    thrust::pair<int*, int*> end;
    // Because set_diff does not support duplicate data, set_diff is performed
    // separately for each segment of data.
    // TODO(zhangkaihuo): Using hashtable here may get better performance,
    // further tests are needed.
    for (int i = 0; i < kernel_size; i++) {
      int start = offsets[i];
      int stop = i == kernel_size - 1 ? rulebook_len : offsets[i + 1];
      int* key_result_start = (i == 0 ? key_result.data<int>() : end.first);
      int* val_result_start = i == 0 ? val_result.data<int>() : end.second;
      end =
#ifdef PADDLE_WITH_HIP
          thrust::set_difference_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
          thrust::set_difference_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                                        A_key_ptr + start,
                                        A_key_ptr + stop,
                                        B_key_ptr,
                                        B_key_ptr + x.nnz(),
                                        A_val.data<int>() + start,
                                        B_val.data<int>(),
                                        key_result_start,
                                        val_result_start);
    }

    DistanceKernel<int><<<1, 1, 0, dev_ctx.stream()>>>(
        key_result.data<int>(),
        end.first,
        key_result.data<int>() + rulebook_len);
    int len = 0;
    phi::backends::gpu::GpuMemcpyAsync(&len,
                                       key_result.data<int>() + rulebook_len,
                                       sizeof(int),
#ifdef PADDLE_WITH_HIP
                                       hipMemcpyDeviceToHost,
#else
                                       cudaMemcpyDeviceToHost,
#endif
                                       dev_ctx.stream());
    dev_ctx.Wait();
    // set the diff value = -1, and update counter
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len, 1);
    SetFlagAndUpdateCounterKernel<int><<<config.block_per_grid.x,
                                         config.thread_per_block,
                                         kernel_size * sizeof(int),
                                         dev_ctx.stream()>>>(
        val_result.data<int>(),
        len,
        rulebook_len,
        kernel_size,
        rulebook_ptr,
        counter_ptr);
    // remove -1
#ifdef PADDLE_WITH_HIP
    int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
    int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                               rulebook_ptr,
                               rulebook_ptr + 3 * rulebook_len,
                               -1);
    DistanceKernel<int><<<1, 1, 0, dev_ctx.stream()>>>(
        rulebook_ptr, last, key_result.data<int>() + rulebook_len);
    phi::backends::gpu::GpuMemcpyAsync(&rulebook_len,
                                       key_result.data<int>() + rulebook_len,
                                       sizeof(int),
#ifdef PADDLE_WITH_HIP
                                       hipMemcpyDeviceToHost,
#else
                                       cudaMemcpyDeviceToHost,
#endif
                                       dev_ctx.stream());
    dev_ctx.Wait();
    rulebook_len /= 3;
  }

#ifdef PADDLE_WITH_HIP
  thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                         counter_ptr,
                         counter_ptr + kernel_size,
                         offsets_ptr);

#ifdef PADDLE_WITH_HIP
  phi::backends::gpu::GpuMemcpyAsync(&(*h_counter)[0],
                                     counter_ptr,
                                     kernel_size * sizeof(int),
                                     hipMemcpyDeviceToHost,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(&(*h_offsets)[0],
                                     offsets_ptr,
                                     kernel_size * sizeof(int),
                                     hipMemcpyDeviceToHost,
                                     dev_ctx.stream());
#else
  phi::backends::gpu::GpuMemcpyAsync(&(*h_counter)[0],
                                     counter_ptr,
                                     kernel_size * sizeof(int),
                                     cudaMemcpyDeviceToHost,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(&(*h_offsets)[0],
                                     offsets_ptr,
                                     kernel_size * sizeof(int),
                                     cudaMemcpyDeviceToHost,
                                     dev_ctx.stream());
#endif
  rulebook->Resize({rulebook_rows, rulebook_len});

  // 3. sort or merge the out index
  out_index->ResizeAndAllocate({rulebook_len});
  unique_value->ResizeAndAllocate({rulebook_len});
  unique_key->ResizeAndAllocate({rulebook_len});
  int* out_index_ptr = out_index->data<int>();
  int* unique_value_ptr = unique_value->data<int>();
  int* unique_key_ptr = unique_key->data<int>();

  int* new_end = SortedAndUniqueIndex(dev_ctx,
                                      rulebook_ptr + 2 * rulebook_len,
                                      rulebook_len,
                                      out_index,
                                      unique_key,
                                      unique_value);
  // thrust::distance doesn't support stream parameters
  // const int out_non_zero_num = thrust::distance(unique_key_ptr,
  // new_end.first);
  DistanceKernel<int><<<1, 1>>>(
      unique_key_ptr,
      new_end,
      rulebook_ptr + rulebook_rows * rulebook_cols - 1);
  int out_non_zero_num = 0;
#ifdef PADDLE_WITH_HIP
  phi::backends::gpu::GpuMemcpyAsync(
      &out_non_zero_num,
      rulebook_ptr + rulebook_rows * rulebook_cols - 1,
      sizeof(int),
      hipMemcpyDeviceToHost,
      dev_ctx.stream());
#else
  phi::backends::gpu::GpuMemcpyAsync(
      &out_non_zero_num,
      rulebook_ptr + rulebook_rows * rulebook_cols - 1,
      sizeof(int),
      cudaMemcpyDeviceToHost,
      dev_ctx.stream());
#endif
  dev_ctx.Wait();

  // 5. update out_indices and rulebook by unique_value_ptr
  const int64_t sparse_dim = 4;
  DenseTensorMeta indices_meta(
      DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW);
  DenseTensorMeta values_meta(
      x.dtype(), {out_non_zero_num, kernel_sizes[4]}, x.layout());
  phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
  phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
  int* out_indices_ptr = out_indices.data<int>();

  config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_non_zero_num, 1);
  UpdateIndexKernel<int><<<config.block_per_grid.x,
                           config.thread_per_block.x,
                           0,
                           dev_ctx.stream()>>>(unique_key_ptr,
                                               unique_value_ptr,
                                               out_index_ptr,
                                               out_non_zero_num,
                                               rulebook_len,
                                               d_out_dims,
                                               out_indices_ptr,
                                               rulebook_ptr + 2 * rulebook_len);
  out->SetMember(out_indices, out_values, out_dims, true);
  return rulebook_len;
}

}  // namespace sparse
}  // namespace phi
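The sort/unique/reindex scheme described in the rulebook comment above is easy to check on the host. A minimal sketch in plain C++ over the same example values (the device code does this with thrust; the helper names here are ours, not from the commit):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // 1. the out_index column of the rulebook (the "keys")
  std::vector<int> out_index = {20, 30, 33, 30, 33, 20, 25};

  // 2. the values are the original positions 0..6
  std::vector<int> pos(out_index.size());
  std::iota(pos.begin(), pos.end(), 0);

  // 3. sort the (key, value) pairs by key
  std::sort(pos.begin(), pos.end(),
            [&](int a, int b) { return out_index[a] < out_index[b]; });

  // 4. unique the sorted keys -> 20, 25, 30, 33
  std::vector<int> unique_key;
  for (int p : pos) {
    if (unique_key.empty() || unique_key.back() != out_index[p]) {
      unique_key.push_back(out_index[p]);
    }
  }

  // 5. rewrite each out_index as its rank among the unique keys
  for (int& k : out_index) {
    k = static_cast<int>(
        std::lower_bound(unique_key.begin(), unique_key.end(), k) -
        unique_key.begin());
  }
  for (int k : out_index) std::printf("%d ", k);  // prints: 0 2 3 2 3 0 1
  return 0;
}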
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
@@ -38,7 +38,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& rulebook,
                       const DenseTensor& kernel,
-                      const SparseCooTensor& out_grad,
+                      const DenseTensor& out_grad,
                       const std::vector<int>& paddings,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
@@ -140,8 +140,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
   GatherKernel<T, int><<<config.block_per_grid.x,
                          config.thread_per_block.x,
                          0,
-                         dev_ctx.stream()>>>(
-      out_grad.non_zero_elements().data<T>(),
+                         dev_ctx.stream()>>>(out_grad.data<T>(),
       rulebook_ptr + rulebook_len * 2,
       out_grad_features_ptr,
       rulebook_len,
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
@@ -12,515 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <thrust/execution_policy.h>
-#include <thrust/remove.h>
-#include <thrust/sort.h>
-#include <thrust/unique.h>
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/backends/gpu/gpu_info.h"
-#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
-#include "paddle/phi/kernels/funcs/index_impl.cu.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/primitive/compute_primitives.h"
 #include "paddle/phi/kernels/sparse/convolution_kernel.h"
+#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"

 namespace phi {
 namespace sparse {

-using Dims4D = phi::funcs::sparse::Dims4D;
-[~490 deleted lines: the non-templated __global__ kernels
- SetFlagAndUpdateCounterKernel, UpdateIndexKernel, ProductRuleBookKernel and
- DistanceKernel, plus the int-only ProductRuleBook(..., const DenseTensor&
- kernel, ...) host function — the same logic this commit adds, in templated
- form, to paddle/phi/kernels/sparse/gpu/convolution.cu.h above]
 /**
  * x: (N, D, H, W, C)
  * kernel: (D, H, W, C, OC)
@@ -545,9 +46,12 @@ void Conv3dKernel(const Context& dev_ctx,
   const auto& kernel_dims = kernel.dims();
   int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
   DDim out_dims = {1, 1, 1, 1, 1};
+  std::vector<int> kernel_sizes(kernel_dims.size());
+  for (int i = 0; i < kernel_dims.size(); i++) {
+    kernel_sizes[i] = kernel_dims[i];
+  }
   phi::funcs::sparse::GetOutShape(
-      x_dims, kernel_dims, paddings, dilations, strides, &out_dims);
-  out->set_dims(out_dims);
+      x_dims, kernel_sizes, paddings, dilations, strides, &out_dims);
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
   std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
@@ -574,7 +78,7 @@ void Conv3dKernel(const Context& dev_ctx,
   int n = ProductRuleBook<T, Context>(dev_ctx,
                                       x,
-                                      kernel,
+                                      kernel_sizes,
                                       subm_paddings,
                                       dilations,
                                       subm_strides,
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
@@ -132,11 +132,12 @@ void TestConv3dBase(const std::vector<int>& indices,
   f_verify(out.non_zero_elements().data<T>(), correct_out_features);

   if (backward) {
-    std::vector<DenseTensor> grads = sparse::Conv3dGrad<T>(dev_ctx_cpu,
+    std::vector<DenseTensor> grads =
+        sparse::Conv3dGrad<T>(dev_ctx_cpu,
                               x_tensor,
                               rulebook,
                               kernel_tensor,
-                              out,
+                              out.non_zero_elements(),
                               paddings,
                               dilations,
                               strides,
@@ -231,11 +232,12 @@ void TestConv3dBase(const std::vector<int>& indices,
   f_verify(h_features_tensor.data<T>(), correct_out_features);

   if (backward) {
-    std::vector<DenseTensor> grads = sparse::Conv3dGrad<T>(dev_ctx_gpu,
+    std::vector<DenseTensor> grads =
+        sparse::Conv3dGrad<T>(dev_ctx_gpu,
                               d_x_tensor,
                               d_rulebook,
                               d_kernel_tensor,
-                              d_out,
+                              d_out.non_zero_elements(),
                               paddings,
                               dilations,
                               strides,