Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
660aa8e6
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
660aa8e6
编写于
9月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4958 Fix GPU-ArgMaxWithValue
Merge pull request !4958 from 34bunny/GPU-argmaxwithvalue-fix
上级
469b132c
92082038
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
23 addition
and
28 deletion
+23
-28
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
...end/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
+23
-28
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
浏览文件 @
660aa8e6
...
...
@@ -18,39 +18,34 @@
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template
<
typename
T
,
typename
S
>
__global__
void
ArgmaxWithValue
(
const
T
*
input
,
const
int
bound
,
int
outerSize
,
int
innerSize
,
S
*
index
,
T
*
output
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
(
outerSize
);
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
inputOutterOffset
=
pos
*
innerSize
*
bound
;
int
outputOutterOffset
=
pos
*
innerSize
;
for
(
int
j
=
0
;
j
<
innerSize
;
j
++
)
{
auto
outputInnerOffset
=
outputOutterOffset
+
j
;
S
idx
=
0
;
T
maxData
=
input
[
j
+
inputOutterOffset
];
for
(
S
c
=
0
;
c
<
bound
;
c
++
)
{
int
offset
=
j
+
c
*
innerSize
;
auto
inputData
=
input
[
inputOutterOffset
+
offset
];
idx
=
inputData
>
maxData
?
c
:
idx
;
maxData
=
inputData
>
maxData
?
inputData
:
maxData
;
}
output
[
outputInnerOffset
]
=
maxData
;
index
[
outputInnerOffset
]
=
idx
;
}
__global__
void
ArgmaxWithValue
(
const
T
*
input
,
const
int
bound
,
int
outerSize
,
int
innerSize
,
S
*
index
,
T
*
output
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
outerSize
*
innerSize
;
pos
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
x
=
pos
/
innerSize
%
outerSize
;
int
y
=
pos
%
innerSize
;
S
idx
=
0
;
int
InputOffset
=
x
*
bound
*
innerSize
+
0
*
innerSize
+
y
;
T
maxData
=
input
[
InputOffset
];
for
(
int
i
=
0
;
i
<
bound
;
i
++
)
{
InputOffset
=
x
*
bound
*
innerSize
+
i
*
innerSize
+
y
;
auto
inputData
=
input
[
InputOffset
];
idx
=
inputData
>
maxData
?
i
:
idx
;
maxData
=
inputData
>
maxData
?
inputData
:
maxData
;
}
output
[
pos
]
=
maxData
;
index
[
pos
]
=
idx
;
}
return
;
}
template
<
typename
T
,
typename
S
>
void
CalArgmaxWithValue
(
const
T
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
S
*
index
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
ArgmaxWithValue
<<<
GET_BLOCKS
(
outerSize_
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
bound_
,
outerSize_
,
innerSize_
,
index
,
output
);
void
CalArgmaxWithValue
(
const
T
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
S
*
index
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
ArgmaxWithValue
<<<
GET_BLOCKS
(
outerSize_
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
bound_
,
outerSize_
,
innerSize_
,
index
,
output
);
return
;
}
template
void
CalArgmaxWithValue
<
float
,
int
>(
const
float
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
*
index
,
float
*
output
,
cudaStream_t
cuda_stream
);
template
void
CalArgmaxWithValue
<
half
,
int
>(
const
half
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
*
index
,
half
*
output
,
cudaStream_t
cuda_stream
);
template
void
CalArgmaxWithValue
<
float
,
int
>(
const
float
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
*
index
,
float
*
output
,
cudaStream_t
cuda_stream
);
template
void
CalArgmaxWithValue
<
half
,
int
>(
const
half
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
*
index
,
half
*
output
,
cudaStream_t
cuda_stream
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录