Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
04383c66
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04383c66
编写于
7月 29, 2020
作者:
Y
yankai
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix import
上级
387dac58
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
636 addition
and
42 deletion
+636
-42
mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc
...ommon/anf_exporter/anf_populater/anf_reshape_populater.cc
+35
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h
...common/anf_exporter/anf_populater/anf_reshape_populater.h
+30
-0
mindspore/lite/src/common/anf_importer/import_from_protobuf.cc
...pore/lite/src/common/anf_importer/import_from_protobuf.cc
+546
-42
mindspore/lite/src/common/anf_importer/import_from_protobuf.h
...spore/lite/src/common/anf_importer/import_from_protobuf.h
+25
-0
未找到文件。
mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc
0 → 100644
浏览文件 @
04383c66
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace
mindspore
::
lite
{
int
mindspore
::
lite
::
AnfReshapePopulater
::
Parse
(
mindspore
::
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
{
auto
attr
=
std
::
make_unique
<
schema
::
FlattenT
>
();
node
->
nodeType
=
schema
::
NodeType_CNode
;
node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Flatten
;
node
->
primitive
->
value
.
value
=
attr
.
release
();
return
0
;
}
AnfNodePopulaterRegistrar
anfReshapeParser
(
"Reshape"
,
new
AnfReshapePopulater
());
}
// namespace mindspore::lite
mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h
0 → 100644
浏览文件 @
04383c66
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H
#define MINDSPORE_ANF_RESHAPE_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace
mindspore
::
lite
{
class
AnfReshapePopulater
:
public
AnfNodePopulater
{
public:
AnfReshapePopulater
()
=
default
;
~
AnfReshapePopulater
()
override
=
default
;
int
Parse
(
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
override
;
};
}
// namespace mindspore::lite
#endif // MINDSPORE_ANF_RESHAPE_PARSER_H
mindspore/lite/src/common/anf_importer/import_from_protobuf.cc
浏览文件 @
04383c66
...
...
@@ -15,25 +15,28 @@
*/
#include "src/common/anf_importer/import_from_protobuf.h"
#include <fcntl.h>
#include <unistd.h>
#include <fstream>
#include <functional>
#include <map>
#include <stack>
#include <unordered_map>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>
#include <fstream>
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "frontend/operator/ops.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "src/param_value_lite.h"
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/ir/tensor.h"
#include "
frontend/operator/ops
.h"
#include "
src/param_value_lite
.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h"
#include "include/errorcode.h"
using
string
=
std
::
string
;
using
int32
=
int32_t
;
...
...
@@ -54,26 +57,27 @@ enum ParseForm : int {
};
static
std
::
map
<
std
::
string
,
ParseForm
>
kParseTypeSwitchMap
{
{
"type"
,
FORM_PARSE_TYPE
},
{
"scalar"
,
FORM_PARSE_SCALAR
},
{
"tensor"
,
FORM_PARSE_TENSOR
}};
{
"type"
,
FORM_PARSE_TYPE
},
{
"scalar"
,
FORM_PARSE_SCALAR
},
{
"tensor"
,
FORM_PARSE_TENSOR
}};
static
std
::
unordered_map
<
int
,
TypeId
>
kDefaultValueSwitchMap
{
{
onnx
::
TensorProto_DataType_BOOL
,
kNumberTypeBool
},
{
onnx
::
TensorProto_DataType_INT8
,
kNumberTypeInt8
},
{
onnx
::
TensorProto_DataType_INT16
,
kNumberTypeInt16
},
{
onnx
::
TensorProto_DataType_INT32
,
kNumberTypeInt32
},
{
onnx
::
TensorProto_DataType_INT64
,
kNumberTypeInt64
},
{
onnx
::
TensorProto_DataType_UINT8
,
kNumberTypeUInt8
},
{
onnx
::
TensorProto_DataType_UINT16
,
kNumberTypeUInt16
},
{
onnx
::
TensorProto_DataType_UINT32
,
kNumberTypeUInt32
},
{
onnx
::
TensorProto_DataType_UINT64
,
kNumberTypeUInt64
},
{
onnx
::
TensorProto_DataType_FLOAT16
,
kNumberTypeFloat16
},
{
onnx
::
TensorProto_DataType_FLOAT
,
kNumberTypeFloat32
},
{
onnx
::
TensorProto_DataType_DOUBLE
,
kNumberTypeFloat64
},
{
onnx
::
TensorProto_DataType_STRING
,
kObjectTypeString
},
{
onnx
::
TensorProto_DataType_BOOL
,
kNumberTypeBool
},
{
onnx
::
TensorProto_DataType_INT8
,
kNumberTypeInt8
},
{
onnx
::
TensorProto_DataType_INT16
,
kNumberTypeInt16
},
{
onnx
::
TensorProto_DataType_INT32
,
kNumberTypeInt32
},
{
onnx
::
TensorProto_DataType_INT64
,
kNumberTypeInt64
},
{
onnx
::
TensorProto_DataType_UINT8
,
kNumberTypeUInt8
},
{
onnx
::
TensorProto_DataType_UINT16
,
kNumberTypeUInt16
},
{
onnx
::
TensorProto_DataType_UINT32
,
kNumberTypeUInt32
},
{
onnx
::
TensorProto_DataType_UINT64
,
kNumberTypeUInt64
},
{
onnx
::
TensorProto_DataType_FLOAT16
,
kNumberTypeFloat16
},
{
onnx
::
TensorProto_DataType_FLOAT
,
kNumberTypeFloat32
},
{
onnx
::
TensorProto_DataType_DOUBLE
,
kNumberTypeFloat64
},
{
onnx
::
TensorProto_DataType_STRING
,
kObjectTypeString
},
};
#if 0
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
...
...
@@ -190,16 +194,17 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a
return {};
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
if (attr_tensor.type##_data_size() == 1) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
} else { \
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \
return{}; \
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype( \
const onnx::TensorProto &attr_tensor) { \
if (attr_tensor.type##_data_size() == 1) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
} else { \
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \
return {}; \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
...
...
@@ -634,8 +639,508 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
#endif
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype( \
const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
MS_EXCEPTION_IF_NULL(prim); \
std::vector<ValuePtr> attr_value_vec; \
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
} \
}
PARSE_ONNXATTR_IN_SCALAR_FORM
(
double
,
double
)
PARSE_ONNXATTR_IN_SCALAR_FORM
(
float
,
float
)
PARSE_ONNXATTR_IN_SCALAR_FORM
(
string
,
string
)
PARSE_ONNXATTR_IN_SCALAR_FORM
(
int32
,
int32
)
PARSE_ONNXATTR_IN_SCALAR_FORM
(
int32
,
bool
)
PARSE_ONNXATTR_IN_SCALAR_FORM
(
int64
,
int64
)
PARSE_ONNXATTR_IN_SCALAR_FORM
(
uint64
,
uint64
)
bool
AnfImporterFromProtobuf
::
BuildParameterForFuncGraph
(
const
ParameterPtr
&
node
,
const
onnx
::
ValueInfoProto
&
value_proto
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
value_proto
.
has_type
()
||
!
value_proto
.
has_name
())
{
MS_LOG
(
ERROR
)
<<
"onnx ValueInfoProto has no type or name! "
;
return
false
;
}
node
->
set_name
(
value_proto
.
name
());
const
auto
&
type_proto
=
value_proto
.
type
();
if
(
!
type_proto
.
has_tensor_type
())
{
MS_LOG
(
ERROR
)
<<
"onnx TypeProto has no tesor_type! "
;
return
false
;
}
const
onnx
::
TypeProto_Tensor
&
tensor_typeproto
=
type_proto
.
tensor_type
();
if
(
!
tensor_typeproto
.
has_elem_type
()
||
!
tensor_typeproto
.
has_shape
())
{
MS_LOG
(
ERROR
)
<<
"onnx TypeProto_Tensor has no elem_type or shape! "
;
return
false
;
}
const
onnx
::
TensorShapeProto
&
tensor_shape
=
tensor_typeproto
.
shape
();
std
::
vector
<
int
>
shape
;
for
(
int
i
=
0
;
i
<
tensor_shape
.
dim_size
();
++
i
)
{
shape
.
push_back
(
tensor_shape
.
dim
(
i
).
dim_value
());
}
if
(
kDefaultValueSwitchMap
.
find
(
tensor_typeproto
.
elem_type
())
==
kDefaultValueSwitchMap
.
end
())
{
MS_LOG
(
ERROR
)
<<
"onnx TypeProto_Tensor elem_type is not support yet!"
;
return
false
;
}
auto
type_ptr
=
TypeIdToType
(
kDefaultValueSwitchMap
[
tensor_typeproto
.
elem_type
()]);
auto
abstract_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
type_ptr
,
shape
);
node
->
set_abstract
(
abstract_tensor
);
if
(
default_para_map_
.
find
(
value_proto
.
name
())
!=
default_para_map_
.
end
())
{
tensor
::
Tensor
*
tensor_info
=
new
tensor
::
Tensor
(
kDefaultValueSwitchMap
[
tensor_typeproto
.
elem_type
()],
shape
);
MS_EXCEPTION_IF_NULL
(
tensor_info
);
tensor_info
->
MallocData
();
const
onnx
::
TensorProto
initialize_proto
=
default_para_map_
[
value_proto
.
name
()];
std
::
string
initial_data
=
initialize_proto
.
raw_data
();
auto
*
tensor_data_buf
=
reinterpret_cast
<
uint8_t
*>
(
tensor_info
->
Data
());
MS_EXCEPTION_IF_NULL
(
tensor_data_buf
);
memcpy_s
(
tensor_data_buf
,
tensor_info
->
Size
(),
initial_data
.
data
(),
initial_data
.
size
());
ParamValueLitePtr
param_value
=
std
::
make_shared
<
ParamValueLite
>
();
MS_EXCEPTION_IF_NULL
(
param_value
);
param_value
->
set_tensor_addr
(
tensor_data_buf
);
param_value
->
set_tensor_size
(
tensor_info
->
Size
());
node
->
set_default_param
(
param_value
);
}
anfnode_build_map_
[
value_proto
.
name
()]
=
node
;
return
true
;
}
bool
AnfImporterFromProtobuf
::
ImportParametersForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
MS_LOG
(
INFO
)
<<
"Parameters had default paramerer size is: "
<<
importProto
.
initializer_size
();
for
(
int
i
=
0
;
i
<
importProto
.
initializer_size
();
++
i
)
{
const
onnx
::
TensorProto
&
initializer_proto
=
importProto
.
initializer
(
i
);
if
(
!
initializer_proto
.
has_name
())
{
MS_LOG
(
ERROR
)
<<
"initializer vector of onnx GraphProto has no name at index: "
<<
i
;
return
false
;
}
default_para_map_
[
initializer_proto
.
name
()]
=
initializer_proto
;
}
MS_LOG
(
INFO
)
<<
"all parameters size: "
<<
importProto
.
input_size
();
for
(
int
i
=
0
;
i
<
importProto
.
input_size
();
++
i
)
{
const
onnx
::
ValueInfoProto
&
input_proto
=
importProto
.
input
(
i
);
if
(
!
BuildParameterForFuncGraph
(
outputFuncGraph
->
add_parameter
(),
input_proto
))
{
MS_LOG
(
ERROR
)
<<
"Build parameter for funcgraph fail at index: "
<<
i
;
return
false
;
}
}
return
true
;
}
bool
AnfImporterFromProtobuf
::
ObtainCNodeAttrInTypeForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
MS_EXCEPTION_IF_NULL
(
prim
);
const
int
attr_tensor_type
=
attr_tensor
.
data_type
();
if
(
kDefaultValueSwitchMap
.
find
(
attr_tensor_type
)
==
kDefaultValueSwitchMap
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Obtain attr in type-form has not support input type:"
<<
attr_tensor_type
;
return
false
;
}
prim
->
AddAttr
(
attr_name
,
TypeIdToType
(
kDefaultValueSwitchMap
[
attr_tensor_type
]));
return
true
;
}
bool
AnfImporterFromProtobuf
::
ObtainCNodeAttrInScalarForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
MS_EXCEPTION_IF_NULL
(
prim
);
const
int
attr_tensor_type
=
attr_tensor
.
data_type
();
switch
(
attr_tensor_type
)
{
case
onnx
::
TensorProto_DataType_STRING
:
{
ParseAttrInScalar_string_string
(
prim
,
attr_name
,
attr_tensor
);
break
;
}
case
onnx
::
TensorProto_DataType_INT32
:
{
ParseAttrInScalar_int32_int32
(
prim
,
attr_name
,
attr_tensor
);
break
;
}
case
onnx
::
TensorProto_DataType_INT64
:
{
ParseAttrInScalar_int64_int64
(
prim
,
attr_name
,
attr_tensor
);
break
;
}
case
onnx
::
TensorProto_DataType_UINT64
:
{
ParseAttrInScalar_uint64_uint64
(
prim
,
attr_name
,
attr_tensor
);
break
;
}
case
onnx
::
TensorProto_DataType_FLOAT
:
{
ParseAttrInScalar_float_float
(
prim
,
attr_name
,
attr_tensor
);
break
;
}
case
onnx
::
TensorProto_DataType_DOUBLE
:
{
ParseAttrInScalar_double_double
(
prim
,
attr_name
,
attr_tensor
);
break
;
}
case
onnx
::
TensorProto_DataType_BOOL
:
{
ParseAttrInScalar_int32_bool
(
prim
,
attr_name
,
attr_tensor
);
auto
value
=
prim
->
GetAttr
(
attr_name
);
break
;
}
default:
MS_LOG
(
ERROR
)
<<
"Obtain attr in scalar-form has not support input type: "
<<
attr_tensor_type
;
return
false
;
}
return
true
;
}
bool
AnfImporterFromProtobuf
::
ObtainCNodeAttrInTensorForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
MS_EXCEPTION_IF_NULL
(
prim
);
MS_LOG
(
ERROR
)
<<
"parse attr type don't support attr type is tensor"
;
return
false
;
}
bool
AnfImporterFromProtobuf
::
GetAttrValueForCNode
(
const
PrimitivePtr
&
prim
,
const
onnx
::
AttributeProto
&
attr_proto
)
{
MS_EXCEPTION_IF_NULL
(
prim
);
const
std
::
string
&
attr_name
=
attr_proto
.
name
();
if
(
!
attr_proto
.
has_ref_attr_name
())
{
MS_LOG
(
ERROR
)
<<
"CNode parse attr type has no ref_attr_name"
;
return
false
;
}
const
std
::
string
&
ref_attr_name
=
attr_proto
.
ref_attr_name
();
const
onnx
::
TensorProto
&
attr_tensor
=
attr_proto
.
t
();
switch
(
kParseTypeSwitchMap
[
ref_attr_name
])
{
case
FORM_PARSE_TYPE
:
{
return
ObtainCNodeAttrInTypeForm
(
prim
,
attr_name
,
attr_tensor
);
}
case
FORM_PARSE_SCALAR
:
{
return
ObtainCNodeAttrInScalarForm
(
prim
,
attr_name
,
attr_tensor
);
}
case
FORM_PARSE_TENSOR
:
{
return
ObtainCNodeAttrInTensorForm
(
prim
,
attr_name
,
attr_tensor
);
}
default:
MS_LOG
(
ERROR
)
<<
"parse attr type don't support input of ref_attr_name"
;
return
false
;
}
}
bool
AnfImporterFromProtobuf
::
ObtainValueNodeInTensorForm
(
const
std
::
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
const
int
attr_tensor_type
=
attr_tensor
.
data_type
();
std
::
vector
<
int
>
shape
;
for
(
int
i
=
0
;
i
<
attr_tensor
.
dims_size
();
++
i
)
{
shape
.
push_back
(
attr_tensor
.
dims
(
i
));
}
tensor
::
TensorPtr
tensor_info
=
std
::
make_shared
<
tensor
::
Tensor
>
(
kDefaultValueSwitchMap
[
attr_tensor_type
],
shape
);
tensor_info
->
MallocData
();
const
std
::
string
&
tensor_buf
=
attr_tensor
.
raw_data
();
auto
*
tensor_data_buf
=
reinterpret_cast
<
uint8_t
*>
(
tensor_info
->
Data
());
memcpy_s
(
tensor_data_buf
,
tensor_info
->
Size
(),
tensor_buf
.
data
(),
tensor_buf
.
size
());
auto
new_value_node
=
NewValueNode
(
MakeValue
(
tensor_info
));
MS_EXCEPTION_IF_NULL
(
new_value_node
);
auto
type_ptr
=
TypeIdToType
(
kDefaultValueSwitchMap
[
attr_tensor_type
]);
auto
abstract_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
type_ptr
,
shape
);
new_value_node
->
set_abstract
(
abstract_tensor
);
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
return
true
;
}
bool
AnfImporterFromProtobuf
::
ObtainValueNodeInScalarForm
(
const
std
::
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
const
int
attr_tensor_type
=
attr_tensor
.
data_type
();
ValuePtr
value_ptr
=
nullptr
;
switch
(
attr_tensor_type
)
{
case
onnx
::
TensorProto_DataType_INT32
:
{
std
::
vector
<
int32
>
add_data
;
for
(
int
i
=
0
;
i
<
attr_tensor
.
int32_data_size
();
++
i
)
{
add_data
.
push_back
(
attr_tensor
.
int32_data
(
i
));
}
if
(
add_data
.
size
()
==
1
)
{
value_ptr
=
MakeValue
(
add_data
[
0
]);
}
else
if
(
!
add_data
.
empty
())
{
value_ptr
=
MakeValue
<
std
::
vector
<
int32
>>
(
add_data
);
}
break
;
}
case
onnx
::
TensorProto_DataType_FLOAT
:
{
std
::
vector
<
float
>
add_data
;
for
(
int
i
=
0
;
i
<
attr_tensor
.
float_data_size
();
++
i
)
{
add_data
.
push_back
(
attr_tensor
.
float_data
(
i
));
}
if
(
add_data
.
size
()
==
1
)
{
value_ptr
=
MakeValue
(
add_data
[
0
]);
}
else
if
(
!
add_data
.
empty
())
{
value_ptr
=
MakeValue
<
std
::
vector
<
float
>>
(
add_data
);
}
break
;
}
case
onnx
::
TensorProto_DataType_UNDEFINED
:
{
std
::
vector
<
ValuePtr
>
elems
;
value_ptr
=
std
::
make_shared
<
ValueTuple
>
(
elems
);
break
;
}
default:
MS_LOG
(
ERROR
)
<<
"Obtain attr in scalar-form has not support input type: "
<<
attr_tensor_type
;
return
false
;
}
auto
new_value_node
=
NewValueNode
(
value_ptr
);
MS_EXCEPTION_IF_NULL
(
new_value_node
);
new_value_node
->
set_abstract
(
value_ptr
->
ToAbstract
());
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
return
true
;
}
bool
AnfImporterFromProtobuf
::
ObtainValueNodeInTypeForm
(
const
std
::
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
const
int
attr_tensor_type
=
attr_tensor
.
data_type
();
if
(
kDefaultValueSwitchMap
.
find
(
attr_tensor_type
)
==
kDefaultValueSwitchMap
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Obtain ValueNode attr in type-form has not support input type: "
<<
attr_tensor_type
;
return
false
;
}
auto
new_value_node
=
NewValueNode
(
TypeIdToType
(
kDefaultValueSwitchMap
[
attr_tensor_type
]));
abstract
::
AbstractTypePtr
abs_type
=
std
::
make_shared
<
abstract
::
AbstractType
>
(
std
::
make_shared
<
TypeType
>
());
new_value_node
->
set_abstract
(
abs_type
);
anfnode_build_map_
[
value_node_name
]
=
new_value_node
;
return
true
;
}
bool
AnfImporterFromProtobuf
::
GetAttrValueForValueNode
(
const
std
::
string
&
ref_attr_name
,
const
std
::
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
)
{
switch
(
kParseTypeSwitchMap
[
ref_attr_name
])
{
case
FORM_PARSE_SCALAR
:
{
return
ObtainValueNodeInScalarForm
(
value_node_name
,
attr_tensor
);
}
case
FORM_PARSE_TENSOR
:
{
return
ObtainValueNodeInTensorForm
(
value_node_name
,
attr_tensor
);
}
case
FORM_PARSE_TYPE
:
{
return
ObtainValueNodeInTypeForm
(
value_node_name
,
attr_tensor
);
}
default:
MS_LOG
(
ERROR
)
<<
"parse ValueNode value don't support input of ref_attr_name"
;
return
false
;
}
}
bool
AnfImporterFromProtobuf
::
BuildValueNodeForFuncGraph
(
const
onnx
::
NodeProto
&
node_proto
)
{
const
std
::
string
&
value_node_name
=
node_proto
.
output
(
0
);
const
onnx
::
AttributeProto
&
attr_proto
=
node_proto
.
attribute
(
0
);
if
(
!
attr_proto
.
has_ref_attr_name
())
{
MS_LOG
(
ERROR
)
<<
"parse ValueNode don't have ref_attr_name"
;
return
false
;
}
const
std
::
string
&
ref_attr_name
=
attr_proto
.
ref_attr_name
();
const
onnx
::
TensorProto
&
attr_tensor
=
attr_proto
.
t
();
return
GetAttrValueForValueNode
(
ref_attr_name
,
value_node_name
,
attr_tensor
);
}
abstract
::
AbstractTensorPtr
AnfImporterFromProtobuf
::
GetAbstractForCNode
(
const
onnx
::
AttributeProto
&
attr_proto
)
{
std
::
vector
<
int
>
shape_vec
;
const
onnx
::
TensorProto
&
attr_tensor
=
attr_proto
.
t
();
for
(
int
i
=
0
;
i
<
attr_tensor
.
dims_size
();
++
i
)
{
shape_vec
.
push_back
(
attr_tensor
.
dims
(
i
));
}
auto
type_ptr
=
TypeIdToType
(
kDefaultValueSwitchMap
[
attr_tensor
.
data_type
()]);
auto
abstract_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
type_ptr
,
shape_vec
);
MS_EXCEPTION_IF_NULL
(
abstract_tensor
);
return
abstract_tensor
;
}
CNodePtr
AnfImporterFromProtobuf
::
BuildCNodeForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
NodeProto
&
node_proto
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
if
(
!
node_proto
.
has_op_type
())
{
MS_LOG
(
ERROR
)
<<
"Get CNode op_type failed!"
;
return
nullptr
;
}
const
std
::
string
&
node_name
=
node_proto
.
output
(
0
);
const
std
::
string
&
fullname_with_scope
=
node_proto
.
domain
();
const
std
::
string
&
node_type
=
node_proto
.
op_type
();
PrimitivePtr
prim
=
std
::
make_shared
<
Primitive
>
(
node_type
);
MS_EXCEPTION_IF_NULL
(
prim
);
prim
->
set_instance_name
(
node_type
);
abstract
::
AbstractTensorPtr
abstract
=
nullptr
;
abstract
::
AbstractTensorPtr
abstract_first
=
nullptr
;
abstract
::
AbstractTensorPtr
abstract_second
=
nullptr
;
for
(
int
i
=
0
;
i
<
node_proto
.
attribute_size
();
++
i
)
{
const
onnx
::
AttributeProto
&
attr_proto
=
node_proto
.
attribute
(
i
);
if
(
attr_proto
.
name
()
==
kCNodeShapeAttr
)
{
abstract
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
attr_proto
.
name
()
==
kCNodeShape1Attr
)
{
abstract_first
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
attr_proto
.
name
()
==
kCNodeShape2Attr
)
{
abstract_second
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
!
GetAttrValueForCNode
(
prim
,
attr_proto
))
{
MS_LOG
(
ERROR
)
<<
"Get CNode attr failed!"
;
return
nullptr
;
}
}
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
));
for
(
int
i
=
0
;
i
<
node_proto
.
input_size
();
++
i
)
{
const
std
::
string
&
input_name
=
node_proto
.
input
(
i
);
if
(
anfnode_build_map_
.
find
(
input_name
)
==
anfnode_build_map_
.
end
())
{
MS_LOG
(
ERROR
)
<<
node_name
<<
" input "
<<
i
<<
input_name
<<
"can't find in nodes have parsed"
;
return
nullptr
;
}
inputs
.
push_back
(
anfnode_build_map_
[
input_name
]);
}
CNodePtr
cnode_ptr
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
cnode_ptr
);
if
(
node_type
==
"LayerNorm"
)
{
AbstractBasePtrList
elem
;
elem
.
push_back
(
abstract
);
elem
.
push_back
(
abstract_first
);
elem
.
push_back
(
abstract_second
);
cnode_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
}
else
if
(
node_type
==
"ArgMaxWithValue"
)
{
AbstractBasePtrList
elem
;
elem
.
push_back
(
abstract
);
elem
.
push_back
(
abstract_first
);
cnode_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
}
else
if
(
nullptr
==
abstract
)
{
AbstractBasePtrList
elem
;
for
(
size_t
index
=
1
;
index
<
cnode_ptr
->
inputs
().
size
();
++
index
)
{
elem
.
push_back
(
cnode_ptr
->
input
(
index
)
->
abstract
());
}
cnode_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
}
else
{
cnode_ptr
->
set_abstract
(
abstract
);
}
cnode_ptr
->
set_fullname_with_scope
(
fullname_with_scope
);
anfnode_build_map_
[
node_name
]
=
cnode_ptr
;
return
cnode_ptr
;
}
bool
AnfImporterFromProtobuf
::
BuildReturnForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
,
const
CNodePtr
&
cnode_ptr
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
MS_EXCEPTION_IF_NULL
(
cnode_ptr
);
std
::
vector
<
AnfNodePtr
>
inputs
;
if
(
importProto
.
output_size
()
>
1
)
{
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
AbstractBasePtrList
elem
;
for
(
int
out_size
=
0
;
out_size
<
importProto
.
output_size
();
++
out_size
)
{
const
onnx
::
ValueInfoProto
&
output_node
=
importProto
.
output
(
out_size
);
const
std
::
string
&
out_tuple
=
output_node
.
name
();
inputs
.
push_back
(
anfnode_build_map_
[
out_tuple
]);
elem
.
push_back
(
anfnode_build_map_
[
out_tuple
]
->
abstract
());
}
auto
maketuple_ptr
=
outputFuncGraph
->
NewCNode
(
inputs
);
maketuple_ptr
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
elem
));
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimReturn
));
inputs
.
push_back
(
maketuple_ptr
);
auto
return_node
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
return_node
);
outputFuncGraph
->
set_return
(
return_node
);
MS_LOG
(
INFO
)
<<
"Construct funcgraph finined, all success."
;
}
else
{
const
onnx
::
ValueInfoProto
&
output_node
=
importProto
.
output
(
0
);
const
onnx
::
TypeProto
&
output_typeproto
=
output_node
.
type
();
int
output_type
=
output_typeproto
.
tensor_type
().
elem_type
();
std
::
vector
<
int
>
output_shape
;
for
(
int
i
=
0
;
i
<
output_typeproto
.
tensor_type
().
shape
().
dim_size
();
++
i
)
{
output_shape
.
push_back
(
output_typeproto
.
tensor_type
().
shape
().
dim
(
i
).
dim_value
());
}
auto
type_ptr
=
TypeIdToType
(
kDefaultValueSwitchMap
[
output_type
]);
auto
abstract_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
type_ptr
,
output_shape
);
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimReturn
));
inputs
.
push_back
(
cnode_ptr
);
auto
return_node
=
outputFuncGraph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
return_node
);
return_node
->
set_abstract
(
abstract_tensor
);
outputFuncGraph
->
set_return
(
return_node
);
MS_LOG
(
INFO
)
<<
"Construct funcgraph finined, all success!"
;
}
return
true
;
}
bool
AnfImporterFromProtobuf
::
ImportNodesForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
MS_LOG
(
INFO
)
<<
"The CNdoe size : "
<<
importProto
.
node_size
();
CNodePtr
cnode_ptr
=
nullptr
;
for
(
int
i
=
0
;
i
<
importProto
.
node_size
();
++
i
)
{
const
onnx
::
NodeProto
&
node_proto
=
importProto
.
node
(
i
);
const
std
::
string
&
node_type
=
node_proto
.
op_type
();
if
(
node_type
==
kConstantValueNode
)
{
if
(
!
BuildValueNodeForFuncGraph
(
node_proto
))
{
MS_LOG
(
ERROR
)
<<
"Build ValueNode for funcgraph fail at index: : "
<<
i
;
return
false
;
}
continue
;
}
cnode_ptr
=
BuildCNodeForFuncGraph
(
outputFuncGraph
,
node_proto
);
if
(
cnode_ptr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build CNode for funcgraph fail at index: : "
<<
i
;
return
false
;
}
}
BuildReturnForFuncGraph
(
outputFuncGraph
,
importProto
,
cnode_ptr
);
return
true
;
}
bool
AnfImporterFromProtobuf
::
BuildFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
)
{
bool
AnfImporterFromProtobuf
::
BuildFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
)
{
MS_EXCEPTION_IF_NULL
(
outputFuncGraph
);
GraphDebugInfoPtr
debug_info_ptr
=
outputFuncGraph
->
debug_info
();
MS_EXCEPTION_IF_NULL
(
debug_info_ptr
);
...
...
@@ -651,7 +1156,8 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph
return
ImportNodesForGraph
(
outputFuncGraph
,
importProto
);
}
bool
AnfImporterFromProtobuf
::
ParseModelConfigureInfo
(
const
onnx
::
ModelProto
&
model_proto
)
{
bool
AnfImporterFromProtobuf
::
ParseModelConfigureInfo
(
const
onnx
::
ModelProto
&
model_proto
)
{
if
(
!
model_proto
.
has_producer_name
())
{
MS_LOG
(
ERROR
)
<<
"Parse model producer name from pb file failed!"
;
return
false
;
...
...
@@ -672,7 +1178,6 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo
return
true
;
}
int
AnfImporterFromProtobuf
::
Import
()
{
FuncGraphPtr
dstGraph
=
std
::
make_shared
<
mindspore
::
FuncGraph
>
();
MS_EXCEPTION_IF_NULL
(
dstGraph
);
...
...
@@ -689,9 +1194,9 @@ int AnfImporterFromProtobuf::Import() {
return
RET_OK
;
}
onnx
::
ModelProto
*
AnfImporterFromProtobuf
::
ReadOnnxFromBinary
(
const
std
::
string
&
model_path
)
{
std
::
unique_ptr
<
char
>
onnx_file
(
new
(
std
::
nothrow
)
char
[
PATH_MAX
]{
0
});
onnx
::
ModelProto
*
AnfImporterFromProtobuf
::
ReadOnnxFromBinary
(
const
std
::
string
&
model_path
)
{
std
::
unique_ptr
<
char
>
onnx_file
(
new
(
std
::
nothrow
)
char
[
PATH_MAX
]{
0
});
if
(
realpath
(
model_path
.
c_str
(),
onnx_file
.
get
())
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"open file failed."
;
return
nullptr
;
...
...
@@ -707,11 +1212,10 @@ onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string
delete
onnx_model
;
return
nullptr
;
}
(
void
)
close
(
fd
);
(
void
)
close
(
fd
);
MS_LOG
(
INFO
)
<<
"enter ReadProtoFromBinary success!"
<<
std
::
endl
;
return
onnx_model
;
}
FuncGraphPtr
AnfImporterFromProtobuf
::
GetResult
()
{
return
this
->
func_graph_
;
}
}
// namespace mindspore::lite
mindspore/lite/src/common/anf_importer/import_from_protobuf.h
浏览文件 @
04383c66
...
...
@@ -47,6 +47,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool
ParseModelConfigureInfo
(
const
onnx
::
ModelProto
&
model_proto
);
bool
BuildFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
);
#if 0
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
...
...
@@ -76,6 +77,30 @@ class AnfImporterFromProtobuf : public AnfImporter {
const onnx::TensorProto &attr_tensor);
std::unordered_map<std::string, abstract::AbstractTensorPtr>
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#endif
bool
ImportParametersForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
);
bool
ImportNodesForGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
);
bool
BuildParameterForFuncGraph
(
const
ParameterPtr
&
node
,
const
onnx
::
ValueInfoProto
&
value_proto
);
CNodePtr
BuildCNodeForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
NodeProto
&
node_proto
);
bool
BuildReturnForFuncGraph
(
const
FuncGraphPtr
&
outputFuncGraph
,
const
onnx
::
GraphProto
&
importProto
,
const
CNodePtr
&
cnode_ptr
);
bool
GetAttrValueForCNode
(
const
PrimitivePtr
&
prim
,
const
onnx
::
AttributeProto
&
attr_proto
);
bool
ObtainCNodeAttrInTypeForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
ObtainCNodeAttrInScalarForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
ObtainCNodeAttrInTensorForm
(
const
PrimitivePtr
&
prim
,
const
std
::
string
&
attr_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
BuildValueNodeForFuncGraph
(
const
onnx
::
NodeProto
&
node_proto
);
bool
ObtainValueNodeInTensorForm
(
const
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
ObtainValueNodeInScalarForm
(
const
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
GetAttrValueForValueNode
(
const
string
&
ref_attr_name
,
const
std
::
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
bool
ObtainValueNodeInTypeForm
(
const
string
&
value_node_name
,
const
onnx
::
TensorProto
&
attr_tensor
);
abstract
::
AbstractTensorPtr
GetAbstractForCNode
(
const
onnx
::
AttributeProto
&
attr_proto
);
private:
std
::
string
producer_name_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录