Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
04fdeddf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04fdeddf
编写于
5月 21, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1299 cpu kernel support multiple dtype
Merge pull request !1299 from sunsuodong/ops_int32
上级
eed1c343
df7281c3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
28 addition
and
12 deletion
+28
-12
mindspore/ccsrc/device/cpu/kernel_select_cpu.cc
mindspore/ccsrc/device/cpu/kernel_select_cpu.cc
+15
-9
mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h
mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h
+5
-3
mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h
mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h
+8
-0
未找到文件。
mindspore/ccsrc/device/cpu/kernel_select_cpu.cc
浏览文件 @
04fdeddf
...
...
@@ -59,6 +59,7 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
TypeId
dtype
=
kTypeUnknown
;
if
(
IsInputNotCNode
(
kernel_node
,
input_index
))
{
input_no_cnode_indexes
->
emplace_back
(
input_index
);
dtype
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
kernel_node
,
input_index
);
}
else
{
dtype
=
AnfAlgo
::
GetPrevNodeOutputDeviceDataType
(
kernel_node
,
input_index
);
}
...
...
@@ -84,22 +85,25 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
const
std
::
vector
<
TypeId
>
&
input_types
,
const
std
::
vector
<
size_t
>
&
input_not_cnode_indexes
)
{
if
(
kernel_attr
.
GetInputSize
()
!=
input_types
.
size
())
{
MS_LOG
(
ERROR
)
<<
"
Output num is not equal!"
;
MS_LOG
(
ERROR
)
<<
"
required input num:"
<<
kernel_attr
.
GetInputSize
()
<<
", actual input num:"
<<
input_types
.
size
()
;
return
false
;
}
auto
input_num
=
input_types
.
size
();
for
(
size_t
i
=
0
;
i
<
input_num
;
++
i
)
{
bool
is_not_cnode_idx
=
std
::
any_of
(
input_not_cnode_indexes
.
begin
(),
input_not_cnode_indexes
.
end
(),
[
i
](
size_t
index
)
{
return
index
==
i
;
});
if
(
is_not_cnode_idx
)
{
bool
have_cnode_input
=
(
input_types
.
size
()
!=
input_not_cnode_indexes
.
size
());
if
(
have_cnode_input
&&
is_not_cnode_idx
)
{
continue
;
}
if
(
kernel_attr
.
GetInputAttr
(
i
).
first
!=
input_types
[
i
])
{
MS_LOG
(
ERROR
)
<<
"reg dtype="
<<
kernel_attr
.
GetInputAttr
(
i
).
first
<<
", input dtype="
<<
input_types
[
i
];
MS_LOG
(
DEBUG
)
<<
"required dtype:"
<<
kernel_attr
.
GetInputAttr
(
i
).
first
<<
", actual input dtype:"
<<
input_types
[
i
];
return
false
;
}
if
(
kernel_attr
.
GetInputAttr
(
i
).
second
!=
input_formats
[
i
])
{
MS_LOG
(
ERROR
)
<<
"reg format="
<<
kernel_attr
.
GetInputAttr
(
i
).
second
<<
", input format="
<<
input_formats
[
i
];
MS_LOG
(
DEBUG
)
<<
"required format:"
<<
kernel_attr
.
GetInputAttr
(
i
).
second
<<
", actual input format:"
<<
input_formats
[
i
];
return
false
;
}
}
...
...
@@ -114,17 +118,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std
::
vector
<
std
::
string
>
output_formats
;
std
::
vector
<
TypeId
>
output_types
;
MS_LOG
(
INFO
)
<<
"SetKernelInfo, CNode Name: "
<<
AnfAlgo
::
GetCNodeName
(
kernel_node
);
GetInputFormatsAndDtypes
(
kernel_node
,
&
input_formats
,
&
input_types
,
&
input_not_cnode_indexes
);
auto
kernel_attrs
=
kernel
::
CPUKernelFactory
::
GetInstance
().
GetSupportedKernelAttrList
(
AnfAlgo
::
GetCNodeName
(
kernel_node
));
for
(
auto
&
kernel_attr
:
kernel_attrs
)
{
if
(
IsInputFormatDtypeMatched
(
kernel_attr
,
input_formats
,
input_types
,
input_not_cnode_indexes
))
{
GetOutputFormatsAndDtypes
(
kernel_node
,
kernel_attr
,
&
output_formats
,
&
output_types
);
UpdatePrevNotCNodeFormatDtype
(
kernel_attr
,
input_not_cnode_indexes
,
kernel_node
);
for
(
size_t
index
=
0
;
index
<
kernel_attrs
.
size
();
++
index
)
{
if
(
IsInputFormatDtypeMatched
(
kernel_attrs
[
index
],
input_formats
,
input_types
,
input_not_cnode_indexes
))
{
MS_LOG
(
INFO
)
<<
"Input format and dtype is matched, index: "
<<
index
;
GetOutputFormatsAndDtypes
(
kernel_node
,
kernel_attrs
[
index
],
&
output_formats
,
&
output_types
);
UpdatePrevNotCNodeFormatDtype
(
kernel_attrs
[
index
],
input_not_cnode_indexes
,
kernel_node
);
for
(
auto
&
input_index
:
input_not_cnode_indexes
)
{
input_types
[
input_index
]
=
kernel_attr
.
GetInputAttr
(
input_index
).
first
;
input_types
[
input_index
]
=
kernel_attr
s
[
index
]
.
GetInputAttr
(
input_index
).
first
;
}
break
;
}
...
...
mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h
浏览文件 @
04fdeddf
...
...
@@ -55,10 +55,12 @@ class CPUKernelRegistrar {
~
CPUKernelRegistrar
()
=
default
;
};
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) \
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS)
#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS)
#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<CPUKernel, OPCLASS>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##
OPNAME##_reg(#OPNAME, ATTR,
\
[]() { return std::make_shared<OPCLASS>(); });
static const CPUKernelRegistrar g_cpu_kernel_##
COUNT##_reg(#OPNAME, ATTR,
\
[]() { return std::make_shared<OPCLASS>(); });
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \
...
...
mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h
浏览文件 @
04fdeddf
...
...
@@ -35,10 +35,18 @@ class ReshapeCPUKernel : public CPUKernel {
MS_REG_CPU_KERNEL
(
Reshape
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
ReshapeCPUKernel
);
MS_REG_CPU_KERNEL
(
Reshape
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
ReshapeCPUKernel
);
MS_REG_CPU_KERNEL
(
Flatten
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
ReshapeCPUKernel
);
MS_REG_CPU_KERNEL
(
Flatten
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
ReshapeCPUKernel
);
MS_REG_CPU_KERNEL
(
ExpandDims
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
ReshapeCPUKernel
);
MS_REG_CPU_KERNEL
(
ExpandDims
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
ReshapeCPUKernel
);
}
// namespace kernel
}
// namespace mindspore
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录