Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dca56f47
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dca56f47
编写于
8月 05, 2020
作者:
Z
Zhong Hui
提交者:
GitHub
8月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix invalid read of pnorm gradient function
fix invalid read of pnorm gradient function and delete the unused code
上级
2c9d0f3c
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
5 addition
and
19 deletion
+5
-19
paddle/fluid/operators/p_norm_op.cu
paddle/fluid/operators/p_norm_op.cu
+5
-19
未找到文件。
paddle/fluid/operators/p_norm_op.cu
浏览文件 @
dca56f47
...
...
@@ -99,39 +99,25 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
const
float
porder
,
const
int
pre
,
const
int
axis_n
,
const
int
post
,
const
T
eps
,
T
*
x_grad
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage_sum
;
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
int
num
=
pre
*
post
;
auto
porder_grad
=
static_cast
<
T
>
(
porder
-
1.0
f
);
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
T
sum
=
0.0
;
__shared__
T
row_sum
;
__shared__
T
row_sqrt_norm
;
__shared__
T
row_norm
;
__shared__
T
pnorm_i
;
__shared__
T
yout_i
;
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
sum
+=
x
[
index
]
*
y_grad
[
index
];
}
T
reduce_result
=
BlockReduce
(
temp_storage_sum
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
{
row_sum
=
reduce_result
;
row_sqrt_norm
=
x_norm
[
i
];
row_norm
=
row_sqrt_norm
*
row_sqrt_norm
;
pnorm_i
=
x_norm
[
i
];
yout_i
=
y_grad
[
i
];
}
__syncthreads
();
const
T
pnorm_i
=
x_norm
[
i
];
const
T
yout_i
=
y_grad
[
i
];
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
const
T
x_ij
=
inline_abs
(
x
[
index
]);
const
T
dy_ij
=
y_grad
[
index
];
x_grad
[
index
]
=
inline_pow
(
x_ij
,
porder_grad
)
/
(
inline_pow
(
pnorm_i
,
porder_grad
)
+
eps
)
*
yout_i
*
inline_sign
(
x
[
index
]);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录