Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fdf63b4e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fdf63b4e
编写于
4月 13, 2021
作者:
J
jiangcheng
提交者:
GitHub
4月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize check_finite_and_unscale_op by fused kernel, test=develop (#31954)
上级
4a09c1a1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
84 addition
and
21 deletion
+84
-21
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
+84
-21
未找到文件。
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
浏览文件 @
fdf63b4e
...
...
@@ -26,18 +26,48 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
}
template
<
typename
T
,
typename
MT
>
__global__
void
CheckFiniteAndUnscale
(
const
T
*
in
,
const
MT
*
scale
,
int
num
,
bool
*
found_inf
,
T
*
out
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
num
)
{
MT
val
=
static_cast
<
MT
>
(
in
[
idx
])
*
(
*
scale
);
__global__
void
CheckFiniteAndUnscale
(
const
T
**
xs
,
const
MT
*
scale
,
int64_t
size
,
int64_t
*
starts
,
bool
*
found_inf
,
T
**
outs
)
{
const
int64_t
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
// copy starts array from global memory to shared memory
extern
__shared__
int64_t
s_starts
[];
for
(
int
i
=
threadIdx
.
x
;
i
<=
size
;
i
+=
blockDim
.
x
)
{
s_starts
[
i
]
=
starts
[
i
];
}
__syncthreads
();
const
int64_t
num
=
s_starts
[
size
];
int
pre_xs_index
=
0
;
bool
t_found_inf
=
false
;
const
MT
t_scale
=
*
scale
;
for
(
int64_t
idx
=
tid
;
idx
<
num
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
// get the xs's index of thread
int
xs_index
=
pre_xs_index
;
while
(
idx
<
s_starts
[
xs_index
])
xs_index
++
;
// avoid some tensor's numel is zero
while
(
idx
>=
s_starts
[
xs_index
])
xs_index
++
;
pre_xs_index
=
xs_index
-
1
;
// get in data and out data
const
T
*
in
=
xs
[
pre_xs_index
];
T
*
out
=
outs
[
pre_xs_index
];
int64_t
in_idx
=
idx
-
s_starts
[
pre_xs_index
];
// Unscale
MT
val
=
static_cast
<
MT
>
(
in
[
in_idx
])
*
t_scale
;
T
narrow_val
=
static_cast
<
T
>
(
val
);
out
[
idx
]
=
narrow_val
;
out
[
in_idx
]
=
narrow_val
;
// CheckFinite
if
(
!
isfinite
(
narrow_val
))
{
*
found_inf
=
true
;
t_
found_inf
=
true
;
}
}
if
(
t_found_inf
)
{
*
found_inf
=
true
;
}
}
template
<
typename
T
>
...
...
@@ -63,20 +93,53 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
InverseAndMemset
<
MPDType
><<<
1
,
1
,
0
,
dev_ctx
.
stream
()
>>>
(
scale_data
,
inverse_scale_v
,
found_inf_data
);
for
(
size_t
i
=
0
;
i
<
xs
.
size
();
++
i
)
{
const
auto
*
x
=
xs
[
i
];
auto
*
out
=
outs
[
i
];
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
int
num
=
x
->
numel
();
int
block
=
1024
;
int
grid
=
(
num
+
block
-
1
)
/
block
;
VLOG
(
3
)
<<
"launch kernel"
;
CheckFiniteAndUnscale
<
T
,
MPDType
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x_data
,
inverse_scale_v
,
num
,
found_inf_data
,
out_data
);
VLOG
(
3
)
<<
"finish kernel"
;
size_t
xs_size
=
xs
.
size
();
// calculate each tensor's start index and copy to device
auto
h_starts_tensor
=
memory
::
Alloc
(
platform
::
CPUPlace
(),
(
xs_size
+
1
)
*
sizeof
(
int64_t
));
int64_t
*
h_starts
=
reinterpret_cast
<
int64_t
*>
(
h_starts_tensor
->
ptr
());
auto
d_starts_tensor
=
memory
::
Alloc
(
dev_ctx
,
(
xs_size
+
1
)
*
sizeof
(
int64_t
));
int64_t
*
d_starts
=
reinterpret_cast
<
int64_t
*>
(
d_starts_tensor
->
ptr
());
h_starts
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
xs_size
;
i
++
)
{
// the start index value of each tensor is
// the sum of previous tensor's size
h_starts
[
i
]
=
h_starts
[
i
-
1
]
+
xs
[
i
-
1
]
->
numel
();
}
int64_t
total_num
=
h_starts
[
xs_size
];
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
dev_ctx
.
GetPlace
()),
d_starts
,
platform
::
CPUPlace
(),
h_starts
,
(
xs_size
+
1
)
*
sizeof
(
int64_t
),
dev_ctx
.
stream
());
// copy each tensor's data address to device
auto
h_mem
=
memory
::
Alloc
(
platform
::
CPUPlace
(),
2
*
xs_size
*
sizeof
(
T
*
));
const
T
**
h_xs
=
reinterpret_cast
<
const
T
**>
(
h_mem
->
ptr
());
T
**
h_outs
=
reinterpret_cast
<
T
**>
(
h_mem
->
ptr
())
+
xs_size
;
auto
d_mem
=
memory
::
Alloc
(
dev_ctx
,
2
*
xs_size
*
sizeof
(
T
*
));
const
T
**
d_xs
=
reinterpret_cast
<
const
T
**>
(
d_mem
->
ptr
());
T
**
d_outs
=
reinterpret_cast
<
T
**>
(
d_mem
->
ptr
())
+
xs_size
;
for
(
size_t
i
=
0
;
i
<
xs_size
;
++
i
)
{
h_xs
[
i
]
=
xs
[
i
]
->
data
<
T
>
();
h_outs
[
i
]
=
outs
[
i
]
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
}
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
dev_ctx
.
GetPlace
()),
d_xs
,
platform
::
CPUPlace
(),
h_xs
,
2
*
xs_size
*
sizeof
(
T
*
),
dev_ctx
.
stream
());
// Launch Kernel
int
block
=
1024
;
int
block_num
=
block
*
20
;
// each thread deal with 20 number
int
grid
=
(
total_num
+
block_num
-
1
)
/
block_num
;
VLOG
(
3
)
<<
"launch kernel"
;
CheckFiniteAndUnscale
<
T
,
MPDType
><<<
grid
,
block
,
(
xs_size
+
1
)
*
sizeof
(
int64_t
),
dev_ctx
.
stream
()
>>>
(
d_xs
,
inverse_scale_v
,
xs_size
,
d_starts
,
found_inf_data
,
d_outs
);
VLOG
(
3
)
<<
"finish kernel"
;
}
};
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录