Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
74d12d73
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
74d12d73
编写于
6月 21, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2421 Importing pb model to construct funcgraph
Merge pull request !2421 from yankai10/merge_master_new_fearture
上级
23d0497d
59c66895
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
270 addition
and
108 deletion
+270
-108
include/ms_tensor.h
include/ms_tensor.h
+1
-1
mindspore/ccsrc/ir/anf.h
mindspore/ccsrc/ir/anf.h
+1
-0
mindspore/ccsrc/session/CMakeLists.txt
mindspore/ccsrc/session/CMakeLists.txt
+1
-0
mindspore/ccsrc/session/ascend_inference_session.cc
mindspore/ccsrc/session/ascend_inference_session.cc
+90
-0
mindspore/ccsrc/session/ascend_inference_session.h
mindspore/ccsrc/session/ascend_inference_session.h
+45
-0
mindspore/ccsrc/session/session.cc
mindspore/ccsrc/session/session.cc
+4
-1
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+1
-0
mindspore/ccsrc/utils/base_ref_utils.cc
mindspore/ccsrc/utils/base_ref_utils.cc
+18
-21
mindspore/ccsrc/utils/base_ref_utils.h
mindspore/ccsrc/utils/base_ref_utils.h
+1
-4
mindspore/ccsrc/utils/context/ms_context.h
mindspore/ccsrc/utils/context/ms_context.h
+1
-0
mindspore/ccsrc/utils/load_onnx/anf_converter.cc
mindspore/ccsrc/utils/load_onnx/anf_converter.cc
+0
-28
mindspore/ccsrc/utils/load_onnx/anf_converter.h
mindspore/ccsrc/utils/load_onnx/anf_converter.h
+0
-1
mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc
mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc
+99
-43
mindspore/ccsrc/utils/load_onnx/anf_model_parser.h
mindspore/ccsrc/utils/load_onnx/anf_model_parser.h
+8
-9
未找到文件。
include/ms_tensor.h
浏览文件 @
74d12d73
...
...
@@ -63,7 +63,7 @@ class MS_API MSTensor {
// return A pointer points to data in MSTensor.
virtual
void
*
MutableData
()
const
=
0
;
};
using
MultiTensor
=
std
::
vector
<
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>
>>
;
using
MultiTensor
=
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
;
}
// namespace inference
}
// namespace mindspore
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_
mindspore/ccsrc/ir/anf.h
浏览文件 @
74d12d73
...
...
@@ -217,6 +217,7 @@ class CNode : public AnfNode {
void
set_stop_gradient
(
bool
stop_gradient
)
{
stop_gradient_
=
stop_gradient
;
}
std
::
string
fullname_with_scope
()
override
;
void
set_fullname_with_scope
(
const
std
::
string
full_name
)
{
fullname_with_scope_
=
full_name
;
}
std
::
string
DebugString
(
int
recursive_level
=
1
)
const
override
;
std
::
string
DebugString
(
bool
recursive
)
const
override
{
return
DebugString
(
recursive
?
1
:
0
);
}
...
...
mindspore/ccsrc/session/CMakeLists.txt
浏览文件 @
74d12d73
...
...
@@ -23,6 +23,7 @@ if (ENABLE_D)
file
(
GLOB_RECURSE _D_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"ascend_session.cc"
"ascend_control_parser.cc"
"ascend_inference_session.cc"
)
list
(
APPEND _SESSION_SRC_LIST
${
_D_SRC_LIST
}
)
endif
()
...
...
mindspore/ccsrc/session/ascend_inference_session.cc
0 → 100644
浏览文件 @
74d12d73
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "session/ascend_inference_session.h"
#include "operator/ops.h"
#include "ir/tensor.h"
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "device/kernel_runtime.h"
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
#include "common/trans.h"
#include "kernel/tbe/tbe_python_funcs.h"
#include "utils/config_manager.h"
#include "utils/base_ref_extends.h"
namespace
mindspore
{
namespace
session
{
void
AscendInferenceSession
::
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
)
const
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
std
::
vector
<
tensor
::
TensorPtr
>
inputs
(
inputs_const
);
auto
input_nodes
=
kernel_graph
->
inputs
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
size_t
no_weight_input
=
0
;
for
(
size_t
i
=
0
;
i
<
input_nodes
.
size
();
++
i
)
{
tensor
::
TensorPtr
tensor
=
nullptr
;
if
(
!
input_nodes
[
i
]
->
isa
<
Parameter
>
())
{
MS_LOG
(
ERROR
)
<<
"Kernel graph inputs have anfnode which is not Parameter"
;
continue
;
}
auto
pk_node
=
input_nodes
[
i
]
->
cast
<
ParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
pk_node
);
if
(
AnfAlgo
::
IsParameterWeight
(
pk_node
))
{
auto
param_value
=
std
::
dynamic_pointer_cast
<
ParamValuePy
>
(
pk_node
->
default_param
());
MS_EXCEPTION_IF_NULL
(
param_value
);
auto
py_param
=
param_value
->
value
();
MS_EXCEPTION_IF_NULL
(
py_param
);
py
::
array
py_array
=
py_param
.
cast
<
py
::
array
>
();
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py_array
);
}
else
{
tensor
=
inputs
[
no_weight_input
++
];
}
MS_EXCEPTION_IF_NULL
(
tensor
);
if
(
AnfAlgo
::
OutputAddrExist
(
pk_node
,
0
))
{
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
pk_node
,
0
);
bool
need_sync
=
false
;
if
(
ms_context
->
enable_pynative_infer
())
{
if
(
tensor
->
device_address
().
get
()
==
nullptr
||
tensor
->
device_address
()
!=
device_address
)
{
need_sync
=
true
;
}
}
else
{
if
(
tensor
->
is_dirty
())
{
need_sync
=
true
;
}
else
if
(
tensor
->
device_address
()
!=
device_address
)
{
(
void
)
tensor
->
data_sync
();
need_sync
=
true
;
}
}
if
(
need_sync
)
{
if
(
ms_context
->
execution_mode
()
==
kPynativeMode
||
AnfAlgo
::
IsParameterWeight
(
pk_node
))
{
tensor
->
set_device_address
(
device_address
);
}
MS_EXCEPTION_IF_NULL
(
device_address
);
if
(
!
device_address
->
SyncHostToDevice
(
trans
::
GetRuntimePaddingShape
(
pk_node
,
0
),
LongToSize
(
tensor
->
data
().
nbytes
()),
tensor
->
data_type
(),
tensor
->
data_c
(
false
)))
{
MS_LOG
(
EXCEPTION
)
<<
"SyncHostToDevice failed."
;
}
}
}
tensor
->
set_dirty
(
false
);
}
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/ascend_inference_session.h
0 → 100644
浏览文件 @
74d12d73
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H
#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H
#include <unordered_map>
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "session/ascend_session.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
#include "session/session_factory.h"
#include "session/ascend_control_parser.h"
namespace
mindspore
{
namespace
session
{
class
AscendInferenceSession
:
public
AscendSession
{
public:
AscendInferenceSession
()
=
default
;
~
AscendInferenceSession
()
=
default
;
void
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
)
const
;
};
MS_REG_SESSION
(
kDavinciInferenceDevice
,
AscendInferenceSession
);
}
// namespace session
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H
mindspore/ccsrc/session/session.cc
浏览文件 @
74d12d73
...
...
@@ -124,7 +124,7 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p
});
if
(
has_error
)
{
MS_LOG
(
ERROR
)
<<
"Init Tensor failed, returning empty result"
;
std
::
vector
<
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>
>>
multiTensor
;
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
multiTensor
;
return
multiTensor
;
}
VectorRef
outputs
;
...
...
@@ -135,6 +135,9 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p
int
Session
::
Init
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
RegAllOp
();
auto
ms_context
=
MsContext
::
GetInstance
();
ms_context
->
set_execution_mode
(
kGraphMode
);
ms_context
->
set_device_target
(
kAscendDevice
);
session_impl_
=
session
::
SessionFactory
::
Get
().
Create
(
device
);
if
(
session_impl_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Session create failed!, please make sure target device:"
<<
device
<<
" is available."
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
74d12d73
...
...
@@ -619,6 +619,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
auto
new_cnode
=
CreateNewCNode
(
cnode
,
graph
.
get
());
MS_EXCEPTION_IF_NULL
(
new_cnode
);
new_cnode
->
set_abstract
(
cnode
->
abstract
());
new_cnode
->
set_fullname_with_scope
(
cnode
->
fullname_with_scope
());
new_cnode
->
set_scope
(
cnode
->
scope
());
graph
->
FrontBackendlMapAdd
(
node
,
new_cnode
);
if
(
AnfAlgo
::
CheckPrimitiveType
(
new_cnode
,
prim
::
kPrimReturn
))
{
...
...
mindspore/ccsrc/utils/base_ref_utils.cc
浏览文件 @
74d12d73
...
...
@@ -21,20 +21,27 @@
#include "ir/tensor.h"
namespace
mindspore
{
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
TransformBaseRefToMSTensor
(
const
BaseRef
&
base_ref
)
{
void
IterateFindTensor
(
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
*
msTensors
,
const
VectorRef
&
ref_list
)
{
for
(
size_t
i
=
0
;
i
<
ref_list
.
size
();
++
i
)
{
if
(
utils
::
isa
<
tensor
::
TensorPtr
>
(
ref_list
[
i
]))
{
auto
tensor_ptr
=
utils
::
cast
<
std
::
shared_ptr
<
tensor
::
Tensor
>>
(
ref_list
[
i
]);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
auto
tensor
=
new
inference
::
Tensor
(
tensor_ptr
);
msTensors
->
emplace_back
(
std
::
shared_ptr
<
inference
::
MSTensor
>
(
tensor
));
}
else
if
(
utils
::
isa
<
VectorRef
>
(
ref_list
[
i
]))
{
auto
ref_iter
=
utils
::
cast
<
VectorRef
>
(
ref_list
[
i
]);
IterateFindTensor
(
msTensors
,
ref_iter
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"The output is not a tensor"
;
}
}
}
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
TransformVectorRefToMultiTensor
(
const
VectorRef
&
base_ref
)
{
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
msTensors
;
if
(
utils
::
isa
<
VectorRef
>
(
base_ref
))
{
auto
ref_list
=
utils
::
cast
<
VectorRef
>
(
base_ref
);
for
(
size_t
i
=
0
;
i
<
ref_list
.
size
();
++
i
)
{
if
(
utils
::
isa
<
tensor
::
Tensor
>
(
ref_list
[
i
]))
{
auto
tensor_ptr
=
utils
::
cast
<
std
::
shared_ptr
<
tensor
::
Tensor
>>
(
ref_list
[
i
]);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
auto
tensor
=
new
inference
::
Tensor
(
tensor_ptr
);
msTensors
.
emplace_back
(
std
::
shared_ptr
<
inference
::
MSTensor
>
(
tensor
));
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"The output is not a tensor!"
;
}
}
IterateFindTensor
(
&
msTensors
,
ref_list
);
}
else
if
(
utils
::
isa
<
tensor
::
Tensor
>
(
base_ref
))
{
auto
tensor_ptr
=
utils
::
cast
<
std
::
shared_ptr
<
tensor
::
Tensor
>>
(
base_ref
);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
...
...
@@ -45,14 +52,4 @@ std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(con
}
return
msTensors
;
}
std
::
vector
<
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>>
TransformVectorRefToMultiTensor
(
const
VectorRef
&
vector_ref
)
{
std
::
vector
<
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>>
multiTensor
;
for
(
size_t
i
=
0
;
i
<
vector_ref
.
size
();
++
i
)
{
auto
tensors
=
TransformBaseRefToMSTensor
(
vector_ref
[
i
]);
multiTensor
.
emplace_back
(
tensors
);
}
return
multiTensor
;
}
}
// namespace mindspore
mindspore/ccsrc/utils/base_ref_utils.h
浏览文件 @
74d12d73
...
...
@@ -22,9 +22,6 @@
#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
namespace
mindspore
{
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
TransformBaseRefToMSTensor
(
const
BaseRef
&
base_ref
);
std
::
vector
<
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>>
TransformVectorRefToMultiTensor
(
const
VectorRef
&
vector_ref
);
std
::
vector
<
std
::
shared_ptr
<
inference
::
MSTensor
>>
TransformVectorRefToMultiTensor
(
const
VectorRef
&
base_ref
);
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
mindspore/ccsrc/utils/context/ms_context.h
浏览文件 @
74d12d73
...
...
@@ -41,6 +41,7 @@ const int kPynativeMode = 1;
const
char
kCPUDevice
[]
=
"CPU"
;
const
char
kGPUDevice
[]
=
"GPU"
;
const
char
kAscendDevice
[]
=
"Ascend"
;
const
char
kDavinciInferenceDevice
[]
=
"AscendInference"
;
const
char
kDavinciDevice
[]
=
"Davinci"
;
const
char
KNpuLog
[]
=
"_npu_log"
;
const
std
::
set
<
std
::
string
>
kTargetSet
=
{
kCPUDevice
,
kGPUDevice
,
kAscendDevice
,
kDavinciDevice
};
...
...
mindspore/ccsrc/utils/load_onnx/anf_converter.cc
浏览文件 @
74d12d73
...
...
@@ -96,8 +96,6 @@ std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file
ReadOnnxFromBinary
(
modelFile
,
&
model_
);
MSANFModelParser
model_parser
;
FuncGraphPtr
dstgraph_ptr
=
model_parser
.
Parse
(
model_
);
MS_EXCEPTION_IF_NULL
(
dstgraph_ptr
);
TestFuncGraphBuild
(
dstgraph_ptr
);
return
dstgraph_ptr
;
}
...
...
@@ -111,33 +109,7 @@ std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const
}
MSANFModelParser
model_parser
;
FuncGraphPtr
dstgraph_ptr
=
model_parser
.
Parse
(
model_
);
MS_EXCEPTION_IF_NULL
(
dstgraph_ptr
);
TestFuncGraphBuild
(
dstgraph_ptr
);
return
dstgraph_ptr
;
}
int
AnfConverter
::
TestFuncGraphBuild
(
const
FuncGraphPtr
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
auto
node_return
=
graph
->
get_return
();
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
node_return
);
MS_LOG
(
INFO
)
<<
"node_list size is : "
<<
node_list
.
size
();
for
(
auto
&
node
:
node_list
)
{
if
(
node
->
isa
<
CNode
>
())
{
auto
node_CN
=
node
->
cast
<
CNodePtr
>
();
MS_LOG
(
INFO
)
<<
"CN node: "
<<
node_CN
->
input
(
0
)
->
ToString
()
<<
", input size :"
<<
node_CN
->
size
();
}
else
if
(
node
->
isa
<
Parameter
>
())
{
auto
node_Para
=
node
->
cast
<
ParameterPtr
>
();
if
(
node_Para
->
has_default
())
{
MS_LOG
(
INFO
)
<<
"Parameter node: "
<<
node_Para
->
name
()
<<
"has default value!"
;
}
else
{
MS_LOG
(
INFO
)
<<
"Parameter node: "
<<
node_Para
->
name
();
}
}
else
if
(
node
->
isa
<
ValueNode
>
())
{
auto
node_Value
=
node
->
cast
<
ValueNodePtr
>
();
MS_LOG
(
INFO
)
<<
"Value node: "
<<
node_Value
->
ToString
();
}
}
return
0
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/ccsrc/utils/load_onnx/anf_converter.h
浏览文件 @
74d12d73
...
...
@@ -26,7 +26,6 @@ namespace mindspore {
namespace
lite
{
class
AnfConverter
{
public:
static
int
TestFuncGraphBuild
(
const
FuncGraphPtr
&
graph
);
static
std
::
shared_ptr
<
FuncGraph
>
RunAnfConverter
(
const
std
::
string
&
file_path
);
static
std
::
shared_ptr
<
FuncGraph
>
RunAnfConverter
(
const
char
*
buf
,
const
size_t
buf_size
);
...
...
mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc
浏览文件 @
74d12d73
...
...
@@ -14,16 +14,17 @@
* limitations under the License.
*/
#include "utils/load_onnx/anf_model_parser.h"
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "utils/load_onnx/anf_model_parser.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/tensor.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "proto/onnx.pb.h"
#include "utils/log_adapter.h"
...
...
@@ -33,6 +34,8 @@ namespace mindspore {
namespace
lite
{
static
constexpr
char
kConstantValueNode
[]
=
"Constant"
;
static
constexpr
char
kCNodeShapeAttr
[]
=
"shape"
;
static
constexpr
char
kCNodeShape1Attr
[]
=
"shape1"
;
static
constexpr
char
kCNodeShape2Attr
[]
=
"shape2"
;
enum
ParseForm
:
int
{
FORM_PARSE_TYPE
=
0
,
FORM_PARSE_SCALAR
=
1
,
...
...
@@ -56,14 +59,15 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
MS_EXCEPTION_IF_NULL(prim); \
std::vector<
valuetype> attr_value_vec;
\
std::vector<
ValuePtr> attr_value_vec;
\
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
attr_value_vec.push_back(static_cast<valuetype>(attr_tensor.type##_data(i))); \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name,
MakeValue<valuetype>(attr_value_vec[0]));
\
prim->AddAttr(attr_name,
attr_value_vec[0]);
\
} else { \
prim->AddAttr(attr_name,
MakeValue<std::vector<valuetype>>(attr_value_vec));
\
prim->AddAttr(attr_name,
std::make_shared<ValueList>(attr_value_vec));
\
} \
}
...
...
@@ -247,17 +251,12 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
const
std
::
string
&
tensor_buf
=
attr_tensor
.
raw_data
();
auto
*
tensor_data_buf
=
reinterpret_cast
<
uint8_t
*>
(
tensor_info
->
data_c
(
true
));
memcpy_s
(
tensor_data_buf
,
tensor_info
->
data
().
nbytes
(),
tensor_buf
.
data
(),
tensor_buf
.
size
());
if
(
attr_tensor_type
==
onnx
::
TensorProto_DataType_FLOAT
)
{
auto
*
data_valuennode
=
reinterpret_cast
<
float
*>
(
tensor_info
->
data_c
());
MS_EXCEPTION_IF_NULL
(
data_valuennode
);
auto
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
MakeValue
(
*
data_valuennode
));
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
}
else
{
auto
*
data_valuenode
=
reinterpret_cast
<
int32
*>
(
tensor_info
->
data_c
());
MS_EXCEPTION_IF_NULL
(
data_valuenode
);
auto
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
MakeValue
(
*
data_valuenode
));
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
}
auto
new_value_node
=
NewValueNode
(
MakeValue
(
tensor_info
));
MS_EXCEPTION_IF_NULL
(
new_value_node
);
auto
tensor_abstract
=
tensor_info
->
ToAbstract
();
MS_EXCEPTION_IF_NULL
(
tensor_abstract
);
new_value_node
->
set_abstract
(
tensor_abstract
);
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
return
true
;
}
...
...
@@ -315,7 +314,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n
MS_LOG
(
ERROR
)
<<
"Obtain ValueNode attr in type-form has not support input type: "
<<
attr_tensor_type
;
return
false
;
}
auto
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
TypeIdToType
(
kDefaultValueSwitchMap
[
attr_tensor_type
]));
auto
new_value_node
=
NewValueNode
(
TypeIdToType
(
kDefaultValueSwitchMap
[
attr_tensor_type
]));
abstract
::
AbstractTypePtr
abs_type
=
std
::
make_shared
<
abstract
::
AbstractType
>
(
std
::
make_shared
<
TypeType
>
());
new_value_node
->
set_abstract
(
abs_type
);
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
return
true
;
}
...
...
@@ -361,31 +362,45 @@ AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto
tensor
::
TensorPtr
tensor_info
=
std
::
make_shared
<
tensor
::
Tensor
>
(
kDefaultValueSwitchMap
[
attr_tensor
.
data_type
()],
shape_vec
);
MS_EXCEPTION_IF_NULL
(
tensor_info
);
return
tensor_info
->
ToAbstract
();
auto
abstract
=
tensor_info
->
ToAbstract
();
MS_EXCEPTION_IF_NULL
(
abstract
);
return
abstract
;
}
bool
MSANFModelParser
::
BuildCNodeForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
NodeProto
&
node_proto
,
const
onnx
::
GraphProto
&
importProto
,
const
bool
&
ret_flag
)
{
CNodePtr
MSANFModelParser
::
BuildCNodeForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
NodeProto
&
node_proto
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
if
(
!
node_proto
.
has_op_type
())
{
MS_LOG
(
ERROR
)
<<
"Get CNode op_type failed!"
;
return
false
;
return
nullptr
;
}
const
std
::
string
&
node_name
=
node_proto
.
output
(
0
);
const
std
::
string
&
fullname_with_scope
=
node_proto
.
domain
();
const
std
::
string
&
node_type
=
node_proto
.
op_type
();
PrimitivePtr
prim
=
std
::
make_shared
<
Primitive
>
(
node_type
);
MS_EXCEPTION_IF_NULL
(
prim
);
prim
->
set_instance_name
(
node_type
);
AbstractBasePtr
abstract
;
AbstractBasePtr
abstract
=
nullptr
;
AbstractBasePtr
abstract_first
=
nullptr
;
AbstractBasePtr
abstract_second
=
nullptr
;
for
(
int
i
=
0
;
i
<
node_proto
.
attribute_size
();
++
i
)
{
const
onnx
::
AttributeProto
&
attr_proto
=
node_proto
.
attribute
(
i
);
if
(
attr_proto
.
name
()
==
kCNodeShapeAttr
)
{
abstract
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
attr_proto
.
name
()
==
kCNodeShape1Attr
)
{
abstract_first
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
attr_proto
.
name
()
==
kCNodeShape2Attr
)
{
abstract_second
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
!
GetAttrValueForCNode
(
prim
,
attr_proto
))
{
MS_LOG
(
ERROR
)
<<
"Get CNode attr failed!"
;
return
false
;
return
nullptr
;
}
}
...
...
@@ -396,16 +411,64 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap
const
std
::
string
&
input_name
=
node_proto
.
input
(
i
);
if
(
anfnode_build_map_
.
find
(
input_name
)
==
anfnode_build_map_
.
end
())
{
MS_LOG
(
ERROR
)
<<
node_name
<<
" input "
<<
i
<<
input_name
<<
"can't find in nodes have parsed"
;
return
false
;
return
nullptr
;
}
inputs
.
push_back
(
anfnode_build_map_
[
input_name
]);
}
CNodePtr
cnode_ptr
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
cnode_ptr
);
cnode_ptr
->
set_abstract
(
abstract
);
if
(
ret_flag
)
{
if
(
node_type
==
"LayerNorm"
)
{
AbstractBasePtrList
elem
;
elem
.
push_back
(
abstract
);
elem
.
push_back
(
abstract_first
);
elem
.
push_back
(
abstract_second
);
cnode_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
}
else
if
(
node_type
==
"ArgMaxWithValue"
)
{
AbstractBasePtrList
elem
;
elem
.
push_back
(
abstract
);
elem
.
push_back
(
abstract_first
);
cnode_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
}
else
if
(
nullptr
==
abstract
)
{
AbstractBasePtrList
elem
;
for
(
size_t
index
=
1
;
index
<
cnode_ptr
->
inputs
().
size
();
++
index
)
{
elem
.
push_back
(
cnode_ptr
->
input
(
index
)
->
abstract
());
}
cnode_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
}
else
{
cnode_ptr
->
set_abstract
(
abstract
);
}
cnode_ptr
->
set_fullname_with_scope
(
fullname_with_scope
);
anfnode_build_map_
[
node_name
]
=
cnode_ptr
;
return
cnode_ptr
;
}
bool
MSANFModelParser
::
BuildReturnForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
,
const
CNodePtr
&
cnode_ptr
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
MS_EXCEPTION_IF_NULL
(
cnode_ptr
);
std
::
vector
<
AnfNodePtr
>
inputs
;
if
(
importProto
.
output_size
()
>
1
)
{
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
AbstractBasePtrList
elem
;
for
(
int
out_size
=
0
;
out_size
<
importProto
.
output_size
();
++
out_size
)
{
const
onnx
::
ValueInfoProto
&
output_node
=
importProto
.
output
(
out_size
);
const
std
::
string
&
out_tuple
=
output_node
.
name
();
inputs
.
push_back
(
anfnode_build_map_
[
out_tuple
]);
elem
.
push_back
(
anfnode_build_map_
[
out_tuple
]
->
abstract
());
}
auto
maketuple_ptr
=
outputFuncGraph
->
NewCNode
(
inputs
);
maketuple_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimReturn
));
inputs
.
push_back
(
maketuple_ptr
);
auto
return_node
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
return_node
);
outputFuncGraph
->
set_return
(
return_node
);
MS_LOG
(
INFO
)
<<
"Construct funcgraph finined, all success."
;
}
else
{
const
onnx
::
ValueInfoProto
&
output_node
=
importProto
.
output
(
0
);
const
::
onnx
::
TypeProto
&
output_typeproto
=
output_node
.
type
();
const
onnx
::
TypeProto
&
output_typeproto
=
output_node
.
type
();
int
output_type
=
output_typeproto
.
tensor_type
().
elem_type
();
std
::
vector
<
int
>
output_shape
;
for
(
int
i
=
0
;
i
<
output_typeproto
.
tensor_type
().
shape
().
dim_size
();
++
i
)
{
...
...
@@ -417,20 +480,19 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimReturn
));
inputs
.
push_back
(
cnode_ptr
);
auto
return_node
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
return_node
);
return_node
->
set_abstract
(
tensor_return
->
ToAbstract
());
outputFuncGraph
->
set_return
(
return_node
);
MS_LOG
(
INFO
)
<<
"Construct funcgraph finined, all success!"
;
}
anfnode_build_map_
[
node_name
]
=
cnode_ptr
;
return
true
;
}
bool
MSANFModelParser
::
ImportNodesForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
bool
return_flag
=
false
;
MS_LOG
(
INFO
)
<<
"The CNdoe size : "
<<
importProto
.
node_size
();
CNodePtr
cnode_ptr
=
nullptr
;
for
(
int
i
=
0
;
i
<
importProto
.
node_size
();
++
i
)
{
return_flag
=
(
i
==
importProto
.
node_size
()
-
1
)
?
true
:
return_flag
;
const
onnx
::
NodeProto
&
node_proto
=
importProto
.
node
(
i
);
const
std
::
string
&
node_type
=
node_proto
.
op_type
();
if
(
node_type
==
kConstantValueNode
)
{
...
...
@@ -440,11 +502,14 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
}
continue
;
}
if
(
!
BuildCNodeForFuncGraph
(
outputFuncGraph
,
node_proto
,
importProto
,
return_flag
))
{
cnode_ptr
=
BuildCNodeForFuncGraph
(
outputFuncGraph
,
node_proto
);
if
(
cnode_ptr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build CNode for funcgraph fail at index: : "
<<
i
;
return
false
;
}
}
BuildReturnForFuncGraph
(
outputFuncGraph
,
importProto
,
cnode_ptr
);
return
true
;
}
...
...
@@ -472,12 +537,12 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode
producer_name_
=
model_proto
.
producer_name
();
MS_LOG
(
INFO
)
<<
"producer_name :"
<<
producer_name_
;
if
(
!
model_proto
.
has_
producer
_version
())
{
if
(
!
model_proto
.
has_
model
_version
())
{
MS_LOG
(
ERROR
)
<<
"Parse model producer version from pb file failed!"
;
return
false
;
}
producer_version_
=
model_proto
.
producer
_version
();
MS_LOG
(
INFO
)
<<
"producer_version : "
<<
producer
_version_
;
model_version_
=
model_proto
.
model
_version
();
MS_LOG
(
INFO
)
<<
"producer_version : "
<<
model
_version_
;
if
(
!
model_proto
.
has_ir_version
())
{
MS_LOG
(
ERROR
)
<<
"Parse model version from pb file failed!"
;
...
...
@@ -485,14 +550,6 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode
}
ir_version_
=
model_proto
.
ir_version
();
MS_LOG
(
INFO
)
<<
"ir_version :"
<<
ir_version_
;
const
onnx
::
OperatorSetIdProto
&
opset_proto
=
model_proto
.
opset_import
(
0
);
if
(
!
opset_proto
.
has_version
())
{
MS_LOG
(
ERROR
)
<<
"Parse opset version from pb file failed!"
;
return
false
;
}
opset_version_
=
opset_proto
.
version
();
MS_LOG
(
INFO
)
<<
"opset_version : "
<<
opset_version_
;
return
true
;
}
...
...
@@ -501,7 +558,6 @@ FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) {
MS_EXCEPTION_IF_NULL
(
dstGraph
);
if
(
!
MSANFParseModelConfigureInfo
(
model_proto
))
{
MS_LOG
(
ERROR
)
<<
"Parse configuration info for pb file failed!"
;
return
nullptr
;
}
const
onnx
::
GraphProto
&
graphBuild
=
model_proto
.
graph
();
if
(
!
BuildFuncGraph
(
dstGraph
,
graphBuild
))
{
...
...
mindspore/ccsrc/utils/load_onnx/anf_model_parser.h
浏览文件 @
74d12d73
...
...
@@ -29,6 +29,7 @@ namespace lite {
using
int32
=
int32_t
;
using
int64
=
int64_t
;
using
uint64
=
uint64_t
;
using
float16
=
Eigen
::
half
;
class
MSANFModelParser
{
public:
MSANFModelParser
()
=
default
;
...
...
@@ -38,17 +39,17 @@ class MSANFModelParser {
bool
MSANFParseModelConfigureInfo
(
const
onnx
::
ModelProto
&
model_proto
);
std
::
string
GetProducerName
()
{
return
producer_name_
;
}
std
::
string
GetProducerVersion
()
{
return
producer
_version_
;
}
int
GetProducerVersion
()
{
return
model
_version_
;
}
int
GetIrVersion
()
{
return
ir_version_
;
}
int
GetOpsetVersion
()
{
return
opset_version_
;
}
private:
bool
BuildFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
);
bool
ImportParametersForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
);
bool
ImportNodesForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
);
bool
BuildParameterForFuncGraph
(
const
ParameterPtr
&
node
,
const
onnx
::
ValueInfoProto
&
value_proto
);
bool
BuildCNodeForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
NodeProto
&
node_proto
,
const
onnx
::
GraphProto
&
importProto
,
const
bool
&
ret_flag
);
CNodePtr
BuildCNodeForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
NodeProto
&
node_proto
);
bool
BuildReturnForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
,
const
CNodePtr
&
cnode_ptr
);
bool
GetAttrValueForCNode
(
const
PrimitivePtr
&
prim
,
const
onnx
::
AttributeProto
&
attr_proto
);
bool
ObtainCNodeAttrInTypeForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
...
...
@@ -63,15 +64,13 @@ class MSANFModelParser {
bool
GetAttrValueForValueNode
(
const
string
&
ref_attr_name
,
const
std
::
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
ObtainValueNodeInTypeForm
(
const
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
AbstractBasePtr
GetAbstractForCNode
(
const
onnx
::
AttributeProto
&
attr_proto
);
std
::
string
producer_name_
;
std
::
string
producer_version_
;
int
ir_version_
{};
int
opset_version_
{};
int
model_version_
;
int
ir_version_
;
std
::
unordered_map
<
std
::
string
,
AnfNodePtr
>
anfnode_build_map_
;
std
::
map
<
std
::
string
,
onnx
::
TensorProto
>
default_para_map_
;
AbstractBasePtr
GetAbstractForCNode
(
const
onnx
::
AttributeProto
&
attr_proto
);
};
}
// namespace lite
}
// namespace mindspore
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录