Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8082ba8a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8082ba8a
编写于
3月 29, 2023
作者:
S
ShenLiang
提交者:
GitHub
3月 29, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[BugFix] fix compute error in fused_dropout_add (#52261)
* fix bg * add utest * add utest
上级
73df2b1e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
30 deletion
+22
-30
paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
...e/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
+15
-22
python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
...paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
+7
-8
未找到文件。
paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
浏览文件 @
8082ba8a
...
@@ -91,9 +91,9 @@ struct NoMaskBwFunctor {
...
@@ -91,9 +91,9 @@ struct NoMaskBwFunctor {
template
<
typename
T
,
typename
Functor
>
template
<
typename
T
,
typename
Functor
>
__global__
void
VectorizedDropoutBackward
(
const
size_t
n
,
__global__
void
VectorizedDropoutBackward
(
const
size_t
n
,
uint64_t
seed
,
uint64_t
seed
,
T
*
src
,
T
*
x
,
T
*
res
,
T
*
y
,
const
T
*
dst
,
const
T
*
out_grad
,
uint64_t
increment
,
uint64_t
increment
,
size_t
main_offset
,
size_t
main_offset
,
Functor
functor
)
{
Functor
functor
)
{
...
@@ -112,44 +112,38 @@ __global__ void VectorizedDropoutBackward(const size_t n,
...
@@ -112,44 +112,38 @@ __global__ void VectorizedDropoutBackward(const size_t n,
#endif
#endif
float
rands
[
kCount
];
float
rands
[
kCount
];
T
src_res
[
kCount
*
2
];
T
x_y
[
kCount
*
2
];
T
res_grad
[
kCount
];
using
Rand
=
phi
::
funcs
::
uniform_distribution
<
float
>
;
using
Rand
=
phi
::
funcs
::
uniform_distribution
<
float
>
;
using
Cast
=
kps
::
IdentityFunctor
<
T
>
;
using
Cast
=
kps
::
IdentityFunctor
<
T
>
;
int
deal_size
=
BLOCK_NUM_X
*
kCount
;
int
deal_size
=
BLOCK_NUM_X
*
kCount
;
size_t
fix
=
idx
*
kCount
;
size_t
fix
=
idx
*
kCount
;
for
(;
fix
<
main_offset
;
fix
+=
stride
)
{
for
(;
fix
<
main_offset
;
fix
+=
stride
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
false
>
(
&
src_res
[
0
],
dst
,
deal_size
);
kps
::
ReadData
<
T
,
kCount
,
1
,
false
>
(
&
x_y
[
0
],
out_grad
+
fix
,
deal_size
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
Rand
>
(
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
&
rands
[
0
],
Rand
(),
&
state
);
// x_grad
kps
::
OperatorTernary
<
T
,
float
,
T
,
Functor
>
(
kps
::
OperatorTernary
<
T
,
float
,
T
,
Functor
>
(
&
src_res
[
0
],
&
src_res
[
0
],
&
rands
[
0
],
functor
,
kCount
);
&
x_y
[
0
],
&
x_y
[
0
],
&
rands
[
0
],
functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
false
>
(
src
+
fix
,
&
src_res
[
0
],
deal_size
);
// res
kps
::
WriteData
<
T
,
kCount
,
1
,
false
>
(
x
+
fix
,
&
x_y
[
0
],
deal_size
);
kps
::
ElementwiseUnary
<
T
,
T
,
kCount
,
1
,
Cast
>
(
kps
::
WriteData
<
T
,
kCount
,
1
,
false
>
(
y
+
fix
,
&
x_y
[
kCount
],
deal_size
);
&
res_grad
[
0
],
&
src_res
[
kCount
],
Cast
());
kps
::
WriteData
<
T
,
kCount
,
1
,
false
>
(
res
+
fix
,
&
res_grad
[
0
],
deal_size
);
if
(
fix
>
idx
*
kCount
+
1
)
{
if
(
fix
>
idx
*
kCount
+
1
)
{
__syncthreads
();
__syncthreads
();
}
}
}
}
int
remainder
=
n
-
fix
;
int
remainder
=
n
-
fix
;
if
(
remainder
>
0
)
{
if
(
remainder
>
0
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
true
>
(
&
src_res
[
0
],
dst
+
fix
,
remainder
);
kps
::
ReadData
<
T
,
kCount
,
1
,
true
>
(
&
x_y
[
0
],
out_grad
+
fix
,
remainder
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
Rand
>
(
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
&
rands
[
0
],
Rand
(),
&
state
);
// x_grad
kps
::
OperatorTernary
<
T
,
float
,
T
,
Functor
>
(
kps
::
OperatorTernary
<
T
,
float
,
T
,
Functor
>
(
&
src_res
[
0
],
&
src_res
[
0
],
&
rands
[
0
],
functor
,
kCount
);
&
x_y
[
0
],
&
x_y
[
0
],
&
rands
[
0
],
functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
true
>
(
src
+
fix
,
&
src_res
[
0
],
remainder
);
// res
kps
::
WriteData
<
T
,
kCount
,
1
,
true
>
(
x
+
fix
,
&
x_y
[
0
],
remainder
);
kps
::
ElementwiseUnary
<
T
,
T
,
kCount
,
1
,
Cast
>
(
kps
::
WriteData
<
T
,
kCount
,
1
,
true
>
(
y
+
fix
,
&
x_y
[
kCount
],
remainder
);
&
res_grad
[
0
],
&
src_res
[
kCount
],
Cast
());
kps
::
WriteData
<
T
,
kCount
,
1
,
true
>
(
res
+
fix
,
&
res_grad
[
0
],
remainder
);
__syncthreads
();
__syncthreads
();
}
}
}
}
...
@@ -201,7 +195,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
...
@@ -201,7 +195,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
size_t
block_size
=
random_prop
[
1
];
size_t
block_size
=
random_prop
[
1
];
size_t
offset
=
random_prop
[
2
];
size_t
offset
=
random_prop
[
2
];
size_t
main_offset
=
random_prop
[
3
];
size_t
main_offset
=
random_prop
[
3
];
auto
functor
=
upscale_in_train
auto
functor
=
upscale_in_train
?
NoMaskBwFunctor
<
T
,
float
>
(
1.0
f
-
dropout_rate
)
?
NoMaskBwFunctor
<
T
,
float
>
(
1.0
f
-
dropout_rate
)
:
NoMaskBwFunctor
<
T
,
float
>
(
1.0
f
-
dropout_rate
,
1.0
f
);
:
NoMaskBwFunctor
<
T
,
float
>
(
1.0
f
-
dropout_rate
,
1.0
f
);
...
...
python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
浏览文件 @
8082ba8a
...
@@ -34,9 +34,9 @@ def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
...
@@ -34,9 +34,9 @@ def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
)
)
class
TestFusedDropoutAdd
(
unittest
.
TestCase
):
class
TestFusedDropoutAdd
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
shape
=
(
2
,
10
,
10
,
2
)
self
.
shape
=
[
2
,
1024
,
2
,
1
]
self
.
dtype
=
'float
64
'
self
.
dtype
=
'float
16
'
self
.
dropout_rate
=
0.
9
self
.
dropout_rate
=
0.
5
self
.
training
=
True
self
.
training
=
True
self
.
mode
=
"upscale_in_train"
self
.
mode
=
"upscale_in_train"
self
.
seed
=
1027
self
.
seed
=
1027
...
@@ -66,9 +66,8 @@ class TestFusedDropoutAdd(unittest.TestCase):
...
@@ -66,9 +66,8 @@ class TestFusedDropoutAdd(unittest.TestCase):
mode
=
self
.
mode
,
mode
=
self
.
mode
,
)
)
fw
.
append
(
out
)
fw
.
append
(
out
)
out_g
=
paddle
.
randn
(
self
.
shape
,
self
.
dtype
)
loss
=
paddle
.
mean
(
out
)
paddle
.
autograd
.
backward
([
out
],
[
out_g
],
True
)
loss
.
backward
()
for
i
in
range
(
count
):
for
i
in
range
(
count
):
bw
.
append
(
data
[
i
].
grad
)
bw
.
append
(
data
[
i
].
grad
)
return
fw
,
bw
return
fw
,
bw
...
@@ -95,7 +94,7 @@ def create_test_class(parent, dtype, mode, training, p, seed):
...
@@ -95,7 +94,7 @@ def create_test_class(parent, dtype, mode, training, p, seed):
)
)
class
TestFusedDropoutAddCase
(
parent
):
class
TestFusedDropoutAddCase
(
parent
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
shape
=
(
2
,
10
,
10
,
2
)
self
.
shape
=
(
2
,
10
24
,
1
,
1
)
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
dropout_rate
=
p
self
.
dropout_rate
=
p
self
.
training
=
training
self
.
training
=
training
...
@@ -168,7 +167,7 @@ class TestFusedDropoutAddStatic(unittest.TestCase):
...
@@ -168,7 +167,7 @@ class TestFusedDropoutAddStatic(unittest.TestCase):
y
=
paddle
.
randn
(
self
.
shape
,
self
.
dtype
)
y
=
paddle
.
randn
(
self
.
shape
,
self
.
dtype
)
fused_d_a
=
FusedDropoutAdd
(
p
=
0.5
)
fused_d_a
=
FusedDropoutAdd
(
p
=
0.5
)
d
=
paddle
.
nn
.
Dropout
(
p
=
0.5
)
d
=
paddle
.
nn
.
Dropout
(
p
=
0.5
)
print
(
d
)
print
(
d
.
extra_repr
()
)
paddle
.
seed
(
2048
)
paddle
.
seed
(
2048
)
fused_out
=
fused_d_a
(
x
,
y
)
fused_out
=
fused_d_a
(
x
,
y
)
paddle
.
seed
(
2048
)
paddle
.
seed
(
2048
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录