Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0a5d625b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a5d625b
编写于
7月 13, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
7月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Opt sparse mask_kernel (#44302)
* opt sparse_mask
上级
ae8ca764
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
84 addition
and
64 deletion
+84
-64
paddle/phi/kernels/sparse/cpu/mask_kernel.cc
paddle/phi/kernels/sparse/cpu/mask_kernel.cc
+1
-1
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
+82
-61
paddle/phi/kernels/sparse/mask_kernel.h
paddle/phi/kernels/sparse/mask_kernel.h
+0
-0
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc
+0
-1
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h
+1
-1
未找到文件。
paddle/phi/kernels/sparse/cpu/
sparse_
mask_kernel.cc
→
paddle/phi/kernels/sparse/cpu/mask_kernel.cc
浏览文件 @
0a5d625b
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/sparse/
sparse_
mask_kernel.h"
#include "paddle/phi/kernels/sparse/mask_kernel.h"
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/ddim.h"
...
...
paddle/phi/kernels/sparse/gpu/
sparse_
mask_kernel.cu
→
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
浏览文件 @
0a5d625b
...
@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
#include "paddle/phi/kernels/sparse/mask_kernel.h"
#include <thrust/binary_search.h>
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
...
@@ -24,6 +22,7 @@ limitations under the License. */
...
@@ -24,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
...
@@ -72,11 +71,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
...
@@ -72,11 +71,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
sparse_offsets
.
data
<
int64_t
>
(),
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
sparse_offsets
.
data
<
int64_t
>
(),
&
h_sparse_offsets
[
0
],
&
h_sparse_offsets
[
0
],
sizeof
(
int64_t
)
*
sparse_dim
,
sizeof
(
int64_t
)
*
sparse_dim
,
#ifdef PADDLE_WITH_HIP
gpuMemcpyHostToDevice
,
hipMemcpyHostToDevice
,
#else
cudaMemcpyHostToDevice
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
stream
());
DenseTensor
out_indices
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
indices
);
DenseTensor
out_indices
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
indices
);
...
@@ -93,14 +88,15 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
...
@@ -93,14 +88,15 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
auto
config
=
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
*
cols
,
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
*
cols
,
1
);
MaskKernel
<
T
,
IntT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
>>>
(
MaskKernel
<
T
,
IntT
>
x_ptr
,
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
indices_ptr
,
x_ptr
,
sparse_offsets
.
data
<
int64_t
>
(),
indices_ptr
,
non_zero_num
,
sparse_offsets
.
data
<
int64_t
>
(),
cols
,
non_zero_num
,
sparse_dim
,
cols
,
out_values_ptr
);
sparse_dim
,
out_values_ptr
);
out
->
SetMember
(
out_indices
,
out_values
,
dims
,
true
);
out
->
SetMember
(
out_indices
,
out_values
,
dims
,
true
);
}
}
...
@@ -121,19 +117,31 @@ void SparseMaskKernel(const Context& dev_ctx,
...
@@ -121,19 +117,31 @@ void SparseMaskKernel(const Context& dev_ctx,
}));
}));
}
}
template
<
typename
T
,
typename
IntT
>
template
<
typename
IntT
>
__global__
void
SparseMaskCopyKernel
(
const
IntT
*
x_indexs
,
__global__
void
MaskTable
(
const
IntT
*
x_indexs
,
const
int
n
,
int
*
table
)
{
const
IntT
*
mask_indexs
,
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
const
IntT
*
bound_out
,
int
index
=
x_indexs
[
i
];
const
T
*
x_values
,
table
[
index
]
=
i
==
0
?
-
1
:
i
;
const
int64_t
n
,
}
const
int64_t
stride
,
}
T
*
out_values
)
{
template
<
typename
T
,
typename
IntT
,
int
VecSize
>
__global__
void
MaskCopy
(
const
IntT
*
mask_indexs
,
const
int
*
table
,
const
int
n
,
const
int
stride
,
const
T
*
x_values
,
T
*
out_values
)
{
using
LoadT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
const
IntT
j
=
bound_out
[
i
];
int
j
=
table
[
mask_indexs
[
i
]];
if
(
j
>=
0
&&
j
<
n
&&
mask_indexs
[
i
]
==
x_indexs
[
j
])
{
if
(
j
!=
0
)
{
for
(
int
k
=
0
;
k
<
stride
;
k
++
)
{
if
(
j
==
-
1
)
j
=
0
;
out_values
[
i
*
stride
+
k
]
=
x_values
[
j
*
stride
+
k
];
for
(
int
k
=
0
;
k
<
stride
;
k
+=
VecSize
)
{
LoadT
vec_x
;
phi
::
Load
<
T
,
VecSize
>
(
x_values
+
j
*
stride
+
k
,
&
vec_x
);
phi
::
Store
<
T
,
VecSize
>
(
vec_x
,
out_values
+
i
*
stride
+
k
);
}
}
}
}
}
}
...
@@ -179,11 +187,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
...
@@ -179,11 +187,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
d_sparse_offsets
.
data
<
IntT
>
(),
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
d_sparse_offsets
.
data
<
IntT
>
(),
sparse_offsets
.
data
(),
sparse_offsets
.
data
(),
sizeof
(
IntT
)
*
sparse_dim
,
sizeof
(
IntT
)
*
sparse_dim
,
#ifdef PADDLE_WITH_HIP
gpuMemcpyHostToDevice
,
hipMemcpyHostToDevice
,
#else
cudaMemcpyHostToDevice
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
stream
());
// 3. flatten x indices and mask indices
// 3. flatten x indices and mask indices
...
@@ -210,37 +214,54 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
...
@@ -210,37 +214,54 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
mask_indexs
.
numel
(),
mask_indexs
.
numel
(),
sparse_dim
,
sparse_dim
,
mask_indexs_ptr
);
mask_indexs_ptr
);
// 4. call thrust::lower_bound
#ifdef PADDLE_WITH_HIP
thrust
::
lower_bound
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
lower_bound
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
x_indexs_ptr
,
x_indexs_ptr
+
x_indexs
.
numel
(),
mask_indexs_ptr
,
mask_indexs_ptr
+
mask_indexs
.
numel
(),
bound_out_ptr
);
// 5. copy value to out
int
table_size
=
1
;
auto
x_dims
=
x
.
dims
();
for
(
int
i
=
0
;
i
<
x_dims
.
size
()
-
1
;
i
++
)
{
table_size
*=
x_dims
[
i
];
}
DenseTensor
table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
table
.
data
<
int
>
(),
0
,
table_size
*
sizeof
(
int
),
dev_ctx
.
stream
());
const
int64_t
stride
=
x
.
dims
().
size
()
==
sparse_dim
?
1
:
x
.
non_zero_elements
().
dims
()[
1
];
*
out
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
non_zero_elements
());
*
out
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
non_zero_elements
());
phi
::
funcs
::
SetConstant
<
GPUContext
,
T
>
set_zero
;
phi
::
funcs
::
SetConstant
<
GPUContext
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
out
,
static_cast
<
T
>
(
0
));
set_zero
(
dev_ctx
,
out
,
static_cast
<
T
>
(
0
));
T
*
out_ptr
=
out
->
data
<
T
>
();
T
*
out_ptr
=
out
->
data
<
T
>
();
config
=
const
int64_t
stride
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
x_indexs
.
numel
(),
1
);
x
.
dims
().
size
()
==
sparse_dim
?
1
:
x
.
non_zero_elements
().
dims
()[
1
];
MaskTable
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
SparseMaskCopyKernel
<<<
config
.
block_per_grid
,
0
,
config
.
thread_per_block
,
dev_ctx
.
stream
()
>>>
(
0
,
x_indexs_ptr
,
x_indexs
.
numel
(),
table
.
data
<
int
>
());
dev_ctx
.
stream
()
>>>
(
x_indexs_ptr
,
config
=
mask_indexs_ptr
,
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
mask_indexs
.
numel
(),
1
);
bound_out_ptr
,
const
int
VecBytes
=
16
;
x
.
non_zero_elements
().
data
<
T
>
(),
const
int
VecSize
=
VecBytes
/
sizeof
(
T
);
mask_indexs
.
numel
(),
if
(
stride
%
VecSize
==
0
)
{
stride
,
MaskCopy
<
T
,
IntT
,
VecSize
>
out_ptr
);
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
mask_indexs_ptr
,
table
.
data
<
int
>
(),
mask_indexs
.
numel
(),
stride
,
x
.
non_zero_elements
().
data
<
T
>
(),
out_ptr
);
}
else
{
MaskCopy
<
T
,
IntT
,
1
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
mask_indexs_ptr
,
table
.
data
<
int
>
(),
mask_indexs
.
numel
(),
stride
,
x
.
non_zero_elements
().
data
<
T
>
(),
out_ptr
);
}
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
...
@@ -257,7 +278,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx,
...
@@ -257,7 +278,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx,
}
// namespace sparse
}
// namespace sparse
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_
mask
,
PD_REGISTER_KERNEL
(
mask
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseMaskKernel
,
phi
::
sparse
::
SparseMaskKernel
,
...
@@ -272,7 +293,7 @@ PD_REGISTER_KERNEL(sparse_mask,
...
@@ -272,7 +293,7 @@ PD_REGISTER_KERNEL(sparse_mask,
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
}
PD_REGISTER_KERNEL
(
sparse_
mask_helper
,
PD_REGISTER_KERNEL
(
mask_helper
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseMaskHelperKernel
,
phi
::
sparse
::
SparseMaskHelperKernel
,
...
...
paddle/phi/kernels/sparse/
sparse_
mask_kernel.h
→
paddle/phi/kernels/sparse/mask_kernel.h
浏览文件 @
0a5d625b
文件已移动
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc
浏览文件 @
0a5d625b
...
@@ -15,7 +15,6 @@ limitations under the License. */
...
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
...
...
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h
浏览文件 @
0a5d625b
...
@@ -16,7 +16,7 @@ limitations under the License. */
...
@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/sparse/
sparse_
mask_kernel.h"
#include "paddle/phi/kernels/sparse/mask_kernel.h"
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录