Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c7382df8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c7382df8
编写于
12月 03, 2018
作者:
Y
Yibing Liu
提交者:
GitHub
12月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Print assert failure id in lookup_table_op (#14698)
上级
566a3259
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
5 deletion
+15
-5
paddle/fluid/operators/lookup_table_op.cu
paddle/fluid/operators/lookup_table_op.cu
+5
-5
paddle/fluid/platform/assert.h
paddle/fluid/platform/assert.h
+10
-0
未找到文件。
paddle/fluid/operators/lookup_table_op.cu
浏览文件 @
c7382df8
...
...
@@ -31,8 +31,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
while
(
idy
<
K
)
{
int64_t
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
_MSG_CODE
(
id
>=
0
,
"received id:"
,
id
);
PADDLE_ASSERT
_MSG_CODE
(
id
<
N
,
"received id:"
,
id
);
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
...
...
@@ -57,9 +57,9 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
int
64_t
id
=
ids
[
idy
];
PADDLE_ASSERT
_MSG_CODE
(
id
>=
0
,
"received id:"
,
id
);
PADDLE_ASSERT
_MSG_CODE
(
id
<
N
,
"received id:"
,
id
);
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
...
...
paddle/fluid/platform/assert.h
浏览文件 @
c7382df8
...
...
@@ -36,6 +36,15 @@ limitations under the License. */
asm("trap;"); \
} \
} while (0)
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s %d).\n", __FILE__, __LINE__, \
TOSTRING(e), m, c); \
asm("trap;"); \
} \
} while (0)
#else
#include <assert.h>
// For cuda, the assertions can affect performance and it is therefore
...
...
@@ -43,4 +52,5 @@ limitations under the License. */
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ASSERT(e) assert((e))
#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
#define PADDLE_ASSERT_MSG_CODE(e, m, c) assert((e) && (m) && (c || 1))
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录