Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b6209eb8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b6209eb8
编写于
6月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2166 GPU update argmaxwithvalue
Merge pull request !2166 from VectorSL/update-argmaxwithvalue
上级
fc3b0b95
46afb18e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
17 addition
and
30 deletion
+17
-30
mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h
...pore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h
+13
-26
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu
+3
-3
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh
...spore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh
+1
-1
未找到文件。
mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h
浏览文件 @
b6209eb8
...
...
@@ -26,15 +26,7 @@ namespace kernel {
template
<
typename
T
,
typename
S
>
class
ArgmaxWithValueGpuKernel
:
public
GpuKernel
{
public:
ArgmaxWithValueGpuKernel
()
:
input_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
),
axis_
(
0
),
dims_
(
1
),
bound_
(
0
),
outerSize_
(
0
),
innerSize_
(
0
)
{}
ArgmaxWithValueGpuKernel
()
:
input_size_
(
0
),
output_size_
(
0
),
bound_
(
0
),
outerSize_
(
0
),
innerSize_
(
0
)
{}
~
ArgmaxWithValueGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
...
...
@@ -46,37 +38,36 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
T
*
input
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
output
=
GetDeviceAddress
<
T
>
(
outputs
,
1
);
S
*
index
=
GetDeviceAddress
<
S
>
(
outputs
,
0
);
CalArgmaxWithValue
(
input_size_
/
sizeof
(
T
),
input
,
bound_
,
outerSize_
,
innerSize_
,
axis_
,
dims_
,
index
,
output
,
CalArgmaxWithValue
(
input_size_
/
sizeof
(
T
),
input
,
bound_
,
outerSize_
,
innerSize_
,
index
,
output
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
s
hape_
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
s
td
::
vector
<
size_t
>
shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
1
);
dims_
=
shape_
.
size
();
axis_
=
GetAttr
<
int
>
(
kernel_node
,
"axis"
);
if
(
axis_
<
0
)
{
axis_
+=
dims_
;
int
dims
=
shape
.
size
();
int
axis
=
GetAttr
<
int
>
(
kernel_node
,
"axis"
);
if
(
axis
<
0
)
{
axis
+=
dims
;
}
input_size_
=
sizeof
(
T
);
for
(
auto
x
:
shape
_
)
{
for
(
auto
x
:
shape
)
{
input_size_
*=
x
;
}
output_size_
=
sizeof
(
S
);
for
(
auto
x
:
output_shape
)
{
output_size_
*=
x
;
}
bound_
=
shape
_
[
axis_
];
bound_
=
shape
[
axis
];
outerSize_
=
1
;
for
(
int
i
=
axis
_
-
1
;
i
>=
0
;
i
--
)
{
outerSize_
*=
shape
_
[
i
];
for
(
int
i
=
axis
-
1
;
i
>=
0
;
i
--
)
{
outerSize_
*=
shape
[
i
];
}
innerSize_
=
1
;
for
(
int
i
=
axis
_
+
1
;
i
<
dims_
;
i
++
)
{
innerSize_
*=
shape
_
[
i
];
for
(
int
i
=
axis
+
1
;
i
<
dims
;
i
++
)
{
innerSize_
*=
shape
[
i
];
}
InitSizeLists
();
return
true
;
...
...
@@ -92,13 +83,9 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
private:
size_t
input_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
std
::
vector
<
size_t
>
shape_
;
int
axis_
;
int
dims_
;
int
bound_
;
int
outerSize_
;
int
innerSize_
;
...
...
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu
浏览文件 @
b6209eb8
...
...
@@ -44,15 +44,15 @@ __global__ void ArgmaxWithValue(size_t size, const T* input, const int bound, in
template
<
typename
T
,
typename
S
>
void
CalArgmaxWithValue
(
size_t
size
,
const
T
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
axis_
,
int
dims_
,
S
*
index
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
S
*
index
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
ArgmaxWithValue
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
size
,
input
,
bound_
,
outerSize_
,
innerSize_
,
index
,
output
);
return
;
}
template
void
CalArgmaxWithValue
<
float
,
int
>(
size_t
size
,
const
float
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
axis_
,
int
dims_
,
int
*
index
,
float
*
output
,
const
int
innerSize_
,
int
*
index
,
float
*
output
,
cudaStream_t
cuda_stream
);
template
void
CalArgmaxWithValue
<
half
,
int
>(
size_t
size
,
const
half
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
axis_
,
int
dims_
,
int
*
index
,
half
*
output
,
const
int
innerSize_
,
int
*
index
,
half
*
output
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh
浏览文件 @
b6209eb8
...
...
@@ -18,5 +18,5 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
template
<
typename
T
,
typename
S
>
void
CalArgmaxWithValue
(
size_t
size
,
const
T
*
input
,
const
int
bound_
,
const
int
outerSize_
,
const
int
innerSize_
,
int
axis_
,
int
dims_
,
S
*
index
,
T
*
output
,
cudaStream_t
cuda_stream
);
S
*
index
,
T
*
output
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录