Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6ea2aa4e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6ea2aa4e
编写于
7月 30, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 30, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3672 fix serving input numbers
Merge pull request !3672 from hexia/fix_input_check
上级
389cb357
31008247
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
20 deletion
+43
-20
mindspore/ccsrc/backend/session/ascend_inference_session.cc
mindspore/ccsrc/backend/session/ascend_inference_session.cc
+40
-19
mindspore/ccsrc/backend/session/ascend_inference_session.h
mindspore/ccsrc/backend/session/ascend_inference_session.h
+3
-1
未找到文件。
mindspore/ccsrc/backend/session/ascend_inference_session.cc
浏览文件 @
6ea2aa4e
...
...
@@ -94,25 +94,33 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vect
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
kernel_graph_inputs
=
kernel_graph
->
inputs
();
size_t
no_weight_input
=
0
;
vector
<
ParameterPtr
>
paras
;
// find parameters of graph inputs
for
(
size_t
i
=
0
;
i
<
kernel_graph_inputs
.
size
();
++
i
)
{
tensor
::
TensorPtr
tensor
=
nullptr
;
if
(
!
kernel_graph_inputs
[
i
]
->
isa
<
Parameter
>
())
{
MS_LOG
(
ERROR
)
<<
"Kernel graph inputs have anfnode which is not Parameter."
;
continue
;
}
auto
parameter
=
kernel_graph_inputs
[
i
]
->
cast
<
ParameterPtr
>
();
if
(
!
AnfAlgo
::
IsParameterWeight
(
parameter
))
{
// compare input number
if
(
no_weight_input
>=
inputs
.
size
())
{
MS_LOG
(
ERROR
)
<<
"Input number is inconsistent. The actual input number ["
<<
inputs
.
size
()
<<
"] less than that of graph."
;
return
false
;
}
auto
input
=
inputs
[
no_weight_input
++
];
if
(
!
CompareInput
(
input
,
parameter
))
{
MS_LOG
(
ERROR
)
<<
"Please check the input information."
;
return
false
;
}
paras
.
push_back
(
parameter
);
}
}
// check inputs
for
(
size_t
i
=
0
;
i
<
paras
.
size
();
++
i
)
{
// compare input number
if
(
paras
.
size
()
!=
inputs
.
size
())
{
MS_LOG
(
ERROR
)
<<
"Input number is inconsistent. The actual input number ["
<<
inputs
.
size
()
<<
"] but the graph input number is ["
<<
paras
.
size
()
<<
"]"
;
MS_LOG
(
ERROR
)
<<
"InputsInfo --"
<<
InputsInfo
(
paras
,
inputs
);
return
false
;
}
auto
input
=
inputs
[
no_weight_input
++
];
if
(
!
CompareInput
(
input
,
paras
[
i
]))
{
MS_LOG
(
ERROR
)
<<
"Please check the input information."
;
MS_LOG
(
ERROR
)
<<
"InputsInfo --"
<<
InputsInfo
(
paras
,
inputs
);
return
false
;
}
}
return
true
;
...
...
@@ -123,12 +131,6 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
MS_EXCEPTION_IF_NULL
(
parameter
);
// compare dims
auto
parameter_shape
=
AnfAlgo
::
GetOutputDeviceShape
(
parameter
,
0
);
if
(
input
->
shape
().
size
()
!=
parameter_shape
.
size
())
{
MS_LOG
(
ERROR
)
<<
"Input dim is inconsistent. The actual dim is "
<<
input
->
shape
().
size
()
<<
", but the parameter dim is "
<<
parameter_shape
.
size
()
<<
". parameter : "
<<
parameter
->
DebugString
();
return
false
;
}
// compare shape
auto
input_shape
=
input
->
shape
();
...
...
@@ -153,12 +155,31 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
return
true
;
}
std
::
string
AscendInferenceSession
::
PrintInputShape
(
std
::
vector
<
size_t
>
shape
)
const
{
template
<
typename
T
>
std
::
string
AscendInferenceSession
::
PrintInputShape
(
std
::
vector
<
T
>
shape
)
const
{
string
res
=
"["
;
for
(
auto
dim
:
shape
)
{
res
+=
" "
+
std
::
to_string
(
dim
);
}
return
res
+
" ]"
;
}
std
::
string
AscendInferenceSession
::
InputsInfo
(
const
std
::
vector
<
ParameterPtr
>
&
paras
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
)
const
{
std
::
string
graph
=
"graph inputs:{ "
;
for
(
size_t
i
=
0
;
i
<
paras
.
size
();
++
i
)
{
graph
+=
std
::
to_string
(
i
)
+
": dims "
+
std
::
to_string
(
AnfAlgo
::
GetOutputDeviceShape
(
paras
[
i
],
0
).
size
())
+
", shape "
+
PrintInputShape
(
AnfAlgo
::
GetOutputDeviceShape
(
paras
[
i
],
0
))
+
", data type "
+
std
::
to_string
(
AnfAlgo
::
GetSelectKernelBuildInfo
(
paras
[
i
])
->
GetOutputDeviceType
(
0
))
+
" }"
;
}
std
::
string
actual
=
"actual inputs:{ "
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
actual
+=
std
::
to_string
(
i
)
+
": dims "
+
std
::
to_string
(
inputs
[
i
]
->
shape
().
size
())
+
", shape "
+
PrintInputShape
(
inputs
[
i
]
->
shape
())
+
", data type "
+
std
::
to_string
(
inputs
[
i
]
->
data_type
())
+
" }"
;
}
return
graph
+
" "
+
actual
;
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/backend/session/ascend_inference_session.h
浏览文件 @
6ea2aa4e
...
...
@@ -41,7 +41,9 @@ class AscendInferenceSession : public AscendSession {
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
override
;
bool
CheckModelInputs
(
uint32_t
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
)
const
override
;
bool
CompareInput
(
const
tensor
::
TensorPtr
&
input
,
const
ParameterPtr
&
parameter
)
const
;
std
::
string
PrintInputShape
(
std
::
vector
<
size_t
>
shape
)
const
;
template
<
typename
T
>
std
::
string
PrintInputShape
(
std
::
vector
<
T
>
shape
)
const
;
std
::
string
InputsInfo
(
const
std
::
vector
<
ParameterPtr
>
&
paras
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
)
const
;
};
MS_REG_SESSION
(
kDavinciInferenceDevice
,
AscendInferenceSession
);
}
// namespace session
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录