magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 11b3fa4b
Authored Aug 27, 2020 by mindspore-ci-bot; committed by Gitee on Aug 27, 2020

!5349 gpu GoogleNet performance optimize

Merge pull request !5349 from limingqi107/master

Parents: 52a7db81, ff6b64a5
Showing 5 changed files with 52 additions and 9 deletions (+52, -9).
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h   +6  -5
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h                   +26 -0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h         +11 -4
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc                   +6  -0
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h                    +3  -0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h

@@ -63,19 +63,21 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     if (!CheckParam(kernel_node)) {
       return false;
     }
     axis_ = GetAttr<int>(kernel_node, "axis");
     if (axis_ < 0) {
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+      auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
       axis_ += SizeToInt(input_shape.size());
     }
+    auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
+    auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
+    axis_ = AxisTransform(origin_data_format, input_format, axis_);
     input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
     inputs_host_ = std::make_unique<T *[]>(input_num_);
     len_axis_ = std::make_unique<int[]>(input_num_);
     for (int i = 0; i < input_num_; i++) {
       size_t input_size = 1;
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+      auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
         input_size *= input_shape[j];
       }

@@ -85,7 +87,7 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     workspace_size_list_.push_back(sizeof(T *) * input_num_);
     workspace_size_list_.push_back(sizeof(int) * input_num_);
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
     output_size_ = 1;
     for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
       output_size_ *= output_shape[i];

@@ -98,7 +100,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
       }
     }
     output_size_list_.push_back(output_size_ * sizeof(T));
     InitSizeLists();
     return true;
   }
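In the hunk above, a possibly negative concat axis is first normalized against the rank of the device shape and then remapped from the logical NCHW layout into the NHWC layout the kernel actually computes in. The following is a minimal standalone sketch of that two-step remapping, using a local copy of the kNCHWToNHWCAxisMap table and a hypothetical input shape and axis; nothing below is MindSpore API.

// Standalone sketch only: shape and axis values are hypothetical.
#include <cassert>
#include <map>
#include <vector>

int main() {
  // Same entries as kNCHWToNHWCAxisMap in gpu_kernel.h.
  std::map<int, int> nchw_to_nhwc = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};

  std::vector<size_t> input_shape = {32, 64, 14, 14};  // logical NCHW shape (example)
  int axis = -3;                                       // user-supplied concat axis

  // Step 1: normalize the negative axis against the rank, as Init() does.
  if (axis < 0) {
    axis += static_cast<int>(input_shape.size());  // -3 -> 1, i.e. the C axis
  }

  // Step 2: the kernel computes in NHWC, so translate the axis into that layout.
  int device_axis = nchw_to_nhwc[axis];  // 1 -> 3, since C is the last axis in NHWC
  assert(device_axis == 3);
  return 0;
}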
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h

@@ -22,6 +22,7 @@
 #include <string>
 #include <vector>
 #include <utility>
+#include <map>
 #include "backend/kernel_compiler/kernel.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
 #include "runtime/device/gpu/gpu_device_manager.h"

@@ -31,6 +32,19 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
 namespace mindspore {
 namespace kernel {
+static std::map<int, int> kNCHWToNHWCAxisMap = {
+  {0, 0},
+  {1, 3},
+  {2, 1},
+  {3, 2},
+};
+static std::map<int, int> kNHWCToNCHWAxisMap = {
+  {0, 0},
+  {1, 2},
+  {2, 3},
+  {3, 1},
+};
+
 class GpuKernel : public KernelMod {
  public:
   virtual ~GpuKernel() = default;

@@ -74,6 +88,18 @@ class GpuKernel : public KernelMod {
     dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
   }
+
+  int AxisTransform(const std::string &origin_data_format, const std::string &cal_format, int axis) {
+    if (((origin_data_format == kOpFormat_DEFAULT) || (origin_data_format == kOpFormat_NCHW)) &&
+        (cal_format == kOpFormat_NHWC)) {
+      return kNCHWToNHWCAxisMap[axis];
+    } else if (((cal_format == kOpFormat_DEFAULT) || (cal_format == kOpFormat_NCHW)) &&
+               (origin_data_format == kOpFormat_NHWC)) {
+      return kNHWCToNCHWAxisMap[axis];
+    } else {
+      return axis;
+    }
+  }
   // transpose shape: NCHW To NHWC
   void ShapeNCHW2NHWC(std::vector<size_t> *shape) {
     std::swap((*shape)[1], (*shape)[3]);
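The two static tables introduced above are intended to be mutual inverses, which is what lets AxisTransform translate an axis in either direction between NCHW and NHWC. The small self-contained round-trip check below operates on local copies of the tables; it is a sketch, not MindSpore code.

// Sketch only: local copies of the two tables, checked for round-trip consistency.
#include <cassert>
#include <map>

int main() {
  std::map<int, int> nchw_to_nhwc = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};
  std::map<int, int> nhwc_to_nchw = {{0, 0}, {1, 2}, {2, 3}, {3, 1}};
  for (int axis = 0; axis < 4; ++axis) {
    assert(nhwc_to_nchw[nchw_to_nhwc[axis]] == axis);  // NCHW -> NHWC -> NCHW
    assert(nchw_to_nhwc[nhwc_to_nchw[axis]] == axis);  // NHWC -> NCHW -> NHWC
  }
  return 0;
}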
mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h

@@ -82,7 +82,7 @@ class AddNGpuFwdKernel : public GpuKernel {
       MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output.";
       return false;
     }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(input_shape);
     if (is_null_input_) {
       MS_LOG(WARNING) << "AddNGpuFwdKernel input is null";

@@ -96,9 +96,16 @@ class AddNGpuFwdKernel : public GpuKernel {
     for (size_t i = 0; i < input_shape.size(); i++) {
       dimA[i] = SizeToInt(input_shape[i]);
     }
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                             SizeToInt(input_shape.size()), dimA),
-                                "cudnnSetTensorNdDescriptor failed");
+    auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
+    if (input_format == kOpFormat_NHWC) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
+                                                               SizeToInt(input_shape.size()), dimA),
+                                  "cudnnSetTensorNdDescriptor failed");
+    } else {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                               SizeToInt(input_shape.size()), dimA),
+                                  "cudnnSetTensorNdDescriptor failed");
+    }
     InitSizeLists();
     return true;
   }
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc

@@ -194,6 +194,12 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
   auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
   MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
   auto inputs_format_position = iter->second.first;
+  // If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
+  if (inputs_format_position.size() == 0) {
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) {
+      inputs_format_position.push_back(input_index);
+    }
+  }
   for (const auto &input_format_position : inputs_format_position) {
     if (input_format_position >= inputs_format->size()) {
       MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
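The added block only applies to ops whose input count is not fixed, such as Concat and AddN: an empty registered position list is expanded to every input index of the node, so each input gets the format transform. Below is a hypothetical standalone sketch of that expansion; ExpandInputPositions and the input count of 4 are illustrative names and values, not MindSpore API.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the new fallback: an empty position list expands
// to all input indices of the node.
std::vector<size_t> ExpandInputPositions(std::vector<size_t> positions, size_t input_num) {
  if (positions.empty()) {
    for (size_t input_index = 0; input_index < input_num; input_index++) {
      positions.push_back(input_index);
    }
  }
  return positions;
}

int main() {
  // e.g. an AddN node with 4 inputs registered with an empty position list
  for (size_t pos : ExpandInputPositions({}, 4)) {
    std::cout << pos << ' ';  // prints: 0 1 2 3
  }
  std::cout << '\n';
  return 0;
}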
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h

@@ -30,6 +30,7 @@ namespace mindspore {
 namespace device {
 namespace gpu {
 // map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform.
+// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
 static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
   {prim::kPrimConv2D->name(), {{0, 1}, {0}}},
   {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},

@@ -47,6 +48,8 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
   {kFusedBatchNormGradEx, {{0, 1}, {0}}},
   {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
   {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
+  {prim::kPrimConcat->name(), {{}, {0}}},
+  {prim::kPrimAddN->name(), {{}, {0}}},
 };

 void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
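As the comment in the header says, the map records, per op name, the input and output positions where a format transform should be inserted; Concat and AddN now register with an empty input list, which, combined with the kernel_info_setter.cc change above, means every input is transformed. The lookup sketch below uses illustrative entries and plain strings rather than prim names; it is not MindSpore code.

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Illustrative copy of the map's shape: op name -> (input positions, output positions).
  std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> format_positions = {
    {"Conv2D", {{0, 1}, {0}}},
    {"Concat", {{}, {0}}},  // empty input list: the op has a variable number of inputs
    {"AddN", {{}, {0}}},
  };

  auto iter = format_positions.find("AddN");
  if (iter != format_positions.end() && iter->second.first.empty()) {
    std::cout << "AddN: transform all inputs, plus output position 0" << std::endl;
  }
  return 0;
}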