Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
038883fd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
038883fd
编写于
7月 20, 2021
作者:
李
李季
提交者:
GitHub
7月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix cast op that can not cast the arrays that the size of arrays is beyond int32 (#34209)
* fix cast
上级
c8fb6fc4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
3 addition
and
2 deletion
+3
-2
paddle/fluid/operators/cast_op.cu
paddle/fluid/operators/cast_op.cu
+2
-1
paddle/fluid/platform/gpu_launch_config.h
paddle/fluid/platform/gpu_launch_config.h
+1
-1
未找到文件。
paddle/fluid/operators/cast_op.cu
浏览文件 @
038883fd
...
@@ -40,7 +40,8 @@ __global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
...
@@ -40,7 +40,8 @@ __global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
AlignedVector
<
InT
,
VecSize
>
;
using
LoadT
=
AlignedVector
<
InT
,
VecSize
>
;
using
StoreT
=
AlignedVector
<
OutT
,
VecSize
>
;
using
StoreT
=
AlignedVector
<
OutT
,
VecSize
>
;
for
(
int
i
=
idx
*
VecSize
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
for
(
int64_t
i
=
idx
*
VecSize
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
InT
in_vec
[
VecSize
];
InT
in_vec
[
VecSize
];
LoadT
*
in_value
=
reinterpret_cast
<
LoadT
*>
(
&
in_vec
);
LoadT
*
in_value
=
reinterpret_cast
<
LoadT
*>
(
&
in_vec
);
*
in_value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
in
[
i
]);
*
in_value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
in
[
i
]);
...
...
paddle/fluid/platform/gpu_launch_config.h
浏览文件 @
038883fd
...
@@ -41,7 +41,7 @@ struct GpuLaunchConfig {
...
@@ -41,7 +41,7 @@ struct GpuLaunchConfig {
};
};
inline
GpuLaunchConfig
GetGpuLaunchConfig1D
(
inline
GpuLaunchConfig
GetGpuLaunchConfig1D
(
const
platform
::
CUDADeviceContext
&
context
,
int
element_count
,
const
platform
::
CUDADeviceContext
&
context
,
int
64_t
element_count
,
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
// HIP will throw GPU memory access fault if threads > 256
// HIP will throw GPU memory access fault if threads > 256
int
max_threads
=
256
)
{
int
max_threads
=
256
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录