Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
393b3bd6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
393b3bd6
编写于
3月 31, 2021
作者:
T
Thunderbrook
提交者:
GitHub
3月 31, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix split core (#31892)
* fix split core * format
上级
3a95a0bc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 addition
and
24 deletion
+24
-24
paddle/fluid/operators/math/concat_and_split.cu
paddle/fluid/operators/math/concat_and_split.cu
+24
-24
未找到文件。
paddle/fluid/operators/math/concat_and_split.cu
浏览文件 @
393b3bd6
...
@@ -114,8 +114,8 @@ __global__ void ConcatKernel(const T** inputs_data, const int in_num,
...
@@ -114,8 +114,8 @@ __global__ void ConcatKernel(const T** inputs_data, const int in_num,
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
64_t
in_row
,
const
int
in_col
,
const
in
t
*
out_cols
,
const
int
64_t
in_col
,
const
int64_
t
*
out_cols
,
int
out_cols_size
,
T
**
outputs_data
)
{
int
out_cols_size
,
T
**
outputs_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
curr_segment
=
0
;
int
curr_segment
=
0
;
...
@@ -159,15 +159,15 @@ __device__ void SplitKernelDetail(const T* input_data, const int in_row,
...
@@ -159,15 +159,15 @@ __device__ void SplitKernelDetail(const T* input_data, const int in_row,
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
64_t
in_row
,
const
int
in_col
,
const
in
t
fixed_out_col
,
const
int
64_t
in_col
,
const
int64_
t
fixed_out_col
,
T
**
outputs_data
)
{
T
**
outputs_data
)
{
SplitKernelDetail
<
T
>
(
input_data
,
in_row
,
in_col
,
fixed_out_col
,
outputs_data
);
SplitKernelDetail
<
T
>
(
input_data
,
in_row
,
in_col
,
fixed_out_col
,
outputs_data
);
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
64_t
in_row
,
const
int
in_col
,
const
in
t
fixed_out_col
,
const
int
64_t
in_col
,
const
int64_
t
fixed_out_col
,
T
*
outputs_addr0
,
T
*
outputs_addr1
)
{
T
*
outputs_addr0
,
T
*
outputs_addr1
)
{
T
*
outputs_data
[
2
];
T
*
outputs_data
[
2
];
outputs_data
[
0
]
=
outputs_addr0
;
outputs_data
[
0
]
=
outputs_addr0
;
...
@@ -176,8 +176,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
...
@@ -176,8 +176,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
64_t
in_row
,
const
int
in_col
,
const
in
t
fixed_out_col
,
const
int
64_t
in_col
,
const
int64_
t
fixed_out_col
,
T
*
outputs_addr0
,
T
*
outputs_addr1
,
T
*
outputs_addr0
,
T
*
outputs_addr1
,
T
*
outputs_addr2
)
{
T
*
outputs_addr2
)
{
T
*
outputs_data
[
3
];
T
*
outputs_data
[
3
];
...
@@ -188,8 +188,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
...
@@ -188,8 +188,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
64_t
in_row
,
const
int
in_col
,
const
in
t
fixed_out_col
,
const
int
64_t
in_col
,
const
int64_
t
fixed_out_col
,
T
*
outputs_addr0
,
T
*
outputs_addr1
,
T
*
outputs_addr0
,
T
*
outputs_addr1
,
T
*
outputs_addr2
,
T
*
outputs_addr3
)
{
T
*
outputs_addr2
,
T
*
outputs_addr3
)
{
T
*
outputs_data
[
4
];
T
*
outputs_data
[
4
];
...
@@ -201,8 +201,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
...
@@ -201,8 +201,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
}
}
static
inline
void
GetBlockDims
(
const
platform
::
CUDADeviceContext
&
context
,
static
inline
void
GetBlockDims
(
const
platform
::
CUDADeviceContext
&
context
,
int
num_rows
,
int
num_cols
,
dim3
*
block_dim
s
,
int
64_t
num_rows
,
int64_t
num_col
s
,
dim3
*
grid_dims
)
{
dim3
*
block_dims
,
dim3
*
grid_dims
)
{
// Set the thread block and grid according to CurrentDeviceId
// Set the thread block and grid according to CurrentDeviceId
const
int
kThreadsPerBlock
=
1024
;
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
kThreadsPerBlock
;
int
block_cols
=
kThreadsPerBlock
;
...
@@ -213,12 +213,12 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context,
...
@@ -213,12 +213,12 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context,
*
block_dims
=
dim3
(
block_cols
,
block_rows
,
1
);
*
block_dims
=
dim3
(
block_cols
,
block_rows
,
1
);
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
64_t
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
grid_cols
=
int
grid_cols
=
std
::
min
((
num_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
std
::
min
((
num_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
int
grid_rows
=
int
grid_rows
=
std
::
min
(
max_blocks
/
grid_cols
,
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
num_rows
/
block_rows
,
1
));
std
::
max
(
num_rows
/
block_rows
,
(
int64_t
)
1
));
*
grid_dims
=
dim3
(
grid_cols
,
grid_rows
,
1
);
*
grid_dims
=
dim3
(
grid_cols
,
grid_rows
,
1
);
}
}
...
@@ -319,22 +319,22 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
...
@@ -319,22 +319,22 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
int
axis
,
std
::
vector
<
framework
::
Tensor
*>*
outputs
)
{
int
axis
,
std
::
vector
<
framework
::
Tensor
*>*
outputs
)
{
// TODO(zcd): Add input data validity checking
// TODO(zcd): Add input data validity checking
int
o_num
=
outputs
->
size
();
int
o_num
=
outputs
->
size
();
int
out_row
=
1
;
int
64_t
out_row
=
1
;
auto
dim_0
=
ref_inputs
[
0
]
->
dims
();
auto
dim_0
=
ref_inputs
[
0
]
->
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
out_row
*=
dim_0
[
i
];
out_row
*=
dim_0
[
i
];
}
}
int
out0_col
=
ref_inputs
[
0
]
->
numel
()
/
out_row
;
int
64_t
out0_col
=
ref_inputs
[
0
]
->
numel
()
/
out_row
;
int
in_col
=
0
,
in_row
=
out_row
;
int
64_t
in_col
=
0
,
in_row
=
out_row
;
bool
has_same_shape
=
true
;
bool
has_same_shape
=
true
;
std
::
vector
<
T
*>
outputs_data
(
o_num
);
std
::
vector
<
T
*>
outputs_data
(
o_num
);
std
::
vector
<
int
>
outputs_cols
(
o_num
+
1
);
std
::
vector
<
int
64_t
>
outputs_cols
(
o_num
+
1
);
outputs_cols
[
0
]
=
0
;
outputs_cols
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
o_num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
o_num
;
++
i
)
{
int
t_col
=
ref_inputs
.
at
(
i
)
->
numel
()
/
out_row
;
int
64_t
t_col
=
ref_inputs
.
at
(
i
)
->
numel
()
/
out_row
;
if
(
has_same_shape
)
{
if
(
has_same_shape
)
{
if
(
t_col
!=
out0_col
)
has_same_shape
=
false
;
if
(
t_col
!=
out0_col
)
has_same_shape
=
false
;
}
}
...
@@ -384,13 +384,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
...
@@ -384,13 +384,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
auto
tmp_dev_ins_col_data
=
auto
tmp_dev_ins_col_data
=
memory
::
Alloc
(
context
,
memory
::
Alloc
(
context
,
outputs_cols
.
size
()
*
sizeof
(
int
));
outputs_cols
.
size
()
*
sizeof
(
int
64_t
));
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()),
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()),
tmp_dev_ins_col_data
->
ptr
(),
platform
::
CPUPlace
(),
tmp_dev_ins_col_data
->
ptr
(),
platform
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
outputs_cols
.
data
()),
reinterpret_cast
<
void
*>
(
outputs_cols
.
data
()),
outputs_cols
.
size
()
*
sizeof
(
int
),
context
.
stream
());
outputs_cols
.
size
()
*
sizeof
(
int
64_t
),
context
.
stream
());
int
*
dev_outs_col_data
=
int
64_t
*
dev_outs_col_data
=
reinterpret_cast
<
int
*>
(
tmp_dev_ins_col_data
->
ptr
());
reinterpret_cast
<
int
64_t
*>
(
tmp_dev_ins_col_data
->
ptr
());
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
in_row
,
in_col
,
dev_outs_col_data
,
input
.
data
<
T
>
(),
in_row
,
in_col
,
dev_outs_col_data
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录