Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
15d03657
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
15d03657
编写于
8月 03, 2020
作者:
Y
yankai
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix anf exporter
上级
488d1904
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
235 addition
and
21 deletion
+235
-21
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+22
-2
mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc
...f_exporter/anf_populater/anf_depthwiseconv2d_populater.cc
+36
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h
...nf_exporter/anf_populater/anf_depthwiseconv2d_populater.h
+29
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc
...ommon/anf_exporter/anf_populater/anf_dequant_populater.cc
+35
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h
...common/anf_exporter/anf_populater/anf_dequant_populater.h
+29
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc
...rc/common/anf_exporter/anf_populater/anf_mul_populater.cc
+1
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc
...anf_exporter/anf_populater/anf_node_populater_registry.cc
+1
-7
mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h
.../anf_exporter/anf_populater/anf_node_populater_registry.h
+1
-1
mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc
.../common/anf_exporter/anf_populater/anf_quant_populater.cc
+35
-0
mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h
...c/common/anf_exporter/anf_populater/anf_quant_populater.h
+29
-0
mindspore/lite/tools/converter/graphdef_transform.cc
mindspore/lite/tools/converter/graphdef_transform.cc
+2
-2
mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.cc
...lite/tools/converter/optimizer/graph/format_trans_pass.cc
+1
-1
mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.h
.../lite/tools/converter/optimizer/graph/format_trans_pass.h
+1
-1
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
...lite/tools/converter/optimizer/node/weight_format_pass.cc
+12
-6
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.h
.../lite/tools/converter/optimizer/node/weight_format_pass.h
+1
-1
未找到文件。
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
浏览文件 @
15d03657
/**
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
*
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* you may not use this file except in compliance with the License.
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include "mindspore/core/ir/primitive.h"
#include "mindspore/core/ir/primitive.h"
#include "src/ir/primitive_t_value.h"
#include "src/ir/primitive_t_value.h"
#include "base/core_ops.h"
#include "base/core_ops.h"
#include "src/ir/tensor.h"
namespace
mindspore
::
lite
{
namespace
mindspore
::
lite
{
schema
::
MetaGraphT
*
AnfExporter
::
Export
(
const
FuncGraphPtr
&
funcGraph
)
{
schema
::
MetaGraphT
*
AnfExporter
::
Export
(
const
FuncGraphPtr
&
funcGraph
)
{
...
@@ -223,9 +224,28 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
...
@@ -223,9 +224,28 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
nodeIdMap
[
paramNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
nodeIdMap
[
paramNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
}
else
if
(
inputNode
->
isa
<
ValueNode
>
())
{
auto
valueNode
=
inputNode
->
cast
<
ValueNodePtr
>
();
auto
paramTensor
=
std
::
make_unique
<
schema
::
TensorT
>
();
auto
value
=
valueNode
->
value
();
if
(
value
->
isa
<
lite
::
tensor
::
Tensor
>
())
{
auto
valueAbstract
=
valueNode
->
abstract
();
auto
abstractTensor
=
utils
::
cast
<
abstract
::
AbstractTensorPtr
>
(
valueAbstract
);
auto
typePtr
=
abstractTensor
->
element
()
->
GetTypeTrack
();
paramTensor
->
dataType
=
typePtr
->
type_id
();
paramTensor
->
dims
=
utils
::
cast
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
())
->
shape
();
paramTensor
->
nodeType
=
schema
::
NodeType_ValueNode
;
auto
data
=
value
->
cast
<
lite
::
tensor
::
TensorPtr
>
();
paramTensor
->
data
.
resize
(
data
->
Size
());
memcpy
(
paramTensor
->
data
.
data
(),
data
->
Data
(),
data
->
Size
());
nodeIdMap
[
valueNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
}
else
{
MS_LOG
(
ERROR
)
<<
"Not support value type , need add support."
;
}
}
}
}
}
if
(
isGraphInput
)
{
if
(
isGraphInput
)
{
graphInputNodes
.
emplace_back
(
fbNode
);
graphInputNodes
.
emplace_back
(
fbNode
);
}
}
...
...
mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc
0 → 100644
浏览文件 @
15d03657
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace
mindspore
::
lite
{
int
mindspore
::
lite
::
AnfDepwiseconv2DPopulater
::
Parse
(
mindspore
::
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
{
auto
attr
=
std
::
make_unique
<
schema
::
DepthwiseConv2DT
>
();
node
->
nodeType
=
schema
::
NodeType_CNode
;
node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_DepthwiseConv2D
;
node
->
primitive
->
value
.
value
=
attr
.
release
();
return
0
;
}
AnfNodePopulaterRegistrar
anfdepthwise2dParser
(
"DepthwiseConv2D"
,
new
AnfDepwiseconv2DPopulater
());
AnfNodePopulaterRegistrar
anfdepthwise2dnativeParser
(
"DepthwiseConv2dNative"
,
new
AnfDepwiseconv2DPopulater
());
}
// namespace mindspore::lite
mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h
0 → 100644
浏览文件 @
15d03657
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace
mindspore
::
lite
{
class
AnfDepwiseconv2DPopulater
:
public
AnfNodePopulater
{
public:
AnfDepwiseconv2DPopulater
()
=
default
;
~
AnfDepwiseconv2DPopulater
()
override
=
default
;
int
Parse
(
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
override
;
};
}
// namespace mindspore::lite
#endif // MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc
0 → 100644
浏览文件 @
15d03657
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace
mindspore
::
lite
{
int
mindspore
::
lite
::
AnfDequantPopulater
::
Parse
(
mindspore
::
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
{
auto
attr
=
std
::
make_unique
<
schema
::
OnnxInt8DequantizeT
>
();
node
->
nodeType
=
schema
::
NodeType_CNode
;
node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_OnnxInt8Dequantize
;
node
->
primitive
->
value
.
value
=
attr
.
release
();
return
0
;
}
AnfNodePopulaterRegistrar
anfDequantParser
(
"Dequant"
,
new
AnfDequantPopulater
());
}
// namespace mindspore::lite
mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h
0 → 100644
浏览文件 @
15d03657
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H
#define MINDSPORE_ANF_DEQUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace
mindspore
::
lite
{
class
AnfDequantPopulater
:
public
AnfNodePopulater
{
public:
AnfDequantPopulater
()
=
default
;
~
AnfDequantPopulater
()
override
=
default
;
int
Parse
(
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
override
;
};
}
// namespace mindspore::lite
#endif // MINDSPORE_ANF_DEQUANT_PARSER_H
mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc
浏览文件 @
15d03657
...
@@ -32,4 +32,5 @@ int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema
...
@@ -32,4 +32,5 @@ int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema
return
0
;
return
0
;
}
}
AnfNodePopulaterRegistrar
anfMulParser
(
"Mul"
,
new
AnfMulPopulater
());
AnfNodePopulaterRegistrar
anfMulParser
(
"Mul"
,
new
AnfMulPopulater
());
AnfNodePopulaterRegistrar
anfMatMulParser
(
"MatMul"
,
new
AnfMulPopulater
());
}
// namespace mindspore::lite
}
// namespace mindspore::lite
mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc
浏览文件 @
15d03657
/**
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* you may not use this file except in compliance with the License.
...
@@ -26,12 +26,6 @@ namespace mindspore {
...
@@ -26,12 +26,6 @@ namespace mindspore {
namespace
lite
{
namespace
lite
{
AnfNodePopulaterRegistry
*
AnfNodePopulaterRegistry
::
GetInstance
()
{
AnfNodePopulaterRegistry
*
AnfNodePopulaterRegistry
::
GetInstance
()
{
static
AnfNodePopulaterRegistry
instance
;
static
AnfNodePopulaterRegistry
instance
;
instance
.
SetNodePopulater
(
"BiasAdd"
,
new
AnfBiasAddPopulater
());
instance
.
SetNodePopulater
(
"Conv2D"
,
new
AnfConvPopulater
());
instance
.
SetNodePopulater
(
"MatMul"
,
new
AnfMatmulPopulater
());
instance
.
SetNodePopulater
(
"MaxPool"
,
new
AnfPoolPopulater
());
instance
.
SetNodePopulater
(
"ReLU"
,
new
AnfActivationPopulater
());
instance
.
SetNodePopulater
(
"Flatten"
,
new
AnfFlattenPopulater
());
return
&
instance
;
return
&
instance
;
}
}
AnfNodePopulater
*
AnfNodePopulaterRegistry
::
GetNodePopulater
(
const
std
::
string
&
name
)
{
AnfNodePopulater
*
AnfNodePopulaterRegistry
::
GetNodePopulater
(
const
std
::
string
&
name
)
{
...
...
mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h
浏览文件 @
15d03657
/**
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* you may not use this file except in compliance with the License.
...
...
mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc
0 → 100644
浏览文件 @
15d03657
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_quant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace
mindspore
::
lite
{
int
mindspore
::
lite
::
AnfQuantPopulater
::
Parse
(
mindspore
::
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
{
auto
attr
=
std
::
make_unique
<
schema
::
OnnxInt8QuantizeT
>
();
node
->
nodeType
=
schema
::
NodeType_CNode
;
node
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
node
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_OnnxInt8Quantize
;
node
->
primitive
->
value
.
value
=
attr
.
release
();
return
0
;
}
AnfNodePopulaterRegistrar
anfQuantParser
(
"Quant"
,
new
AnfQuantPopulater
());
}
// namespace mindspore::lite
mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h
0 → 100644
浏览文件 @
15d03657
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_QUANT_PARSER_H
#define MINDSPORE_ANF_QUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace
mindspore
::
lite
{
class
AnfQuantPopulater
:
public
AnfNodePopulater
{
public:
AnfQuantPopulater
()
=
default
;
~
AnfQuantPopulater
()
override
=
default
;
int
Parse
(
CNodePtr
cnodePtr
,
schema
::
CNodeT
*
node
,
std
::
vector
<
schema
::
TensorT
*>
*
outputs
)
override
;
};
}
// namespace mindspore::lite
#endif // MINDSPORE_ANF_QUANT_PARSER_H
mindspore/lite/tools/converter/graphdef_transform.cc
浏览文件 @
15d03657
...
@@ -123,7 +123,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
...
@@ -123,7 +123,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
MS_LOG
(
ERROR
)
<<
"new weightFormatPass failed"
;
MS_LOG
(
ERROR
)
<<
"new weightFormatPass failed"
;
return
RET_ERROR
;
return
RET_ERROR
;
}
}
//
weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass
->
SetQuantType
(
ctx
.
quantType
);
weightFormatPass
->
SetFmkType
(
ctx
.
fmk
);
weightFormatPass
->
SetFmkType
(
ctx
.
fmk
);
weightFormatOptimizer
.
AddPass
(
weightFormatPass
);
weightFormatOptimizer
.
AddPass
(
weightFormatPass
);
status
=
weightFormatOptimizer
.
Run
(
graphDefT
);
status
=
weightFormatOptimizer
.
Run
(
graphDefT
);
...
@@ -141,7 +141,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
...
@@ -141,7 +141,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
MS_LOG
(
ERROR
)
<<
"new formatTransPass failed"
;
MS_LOG
(
ERROR
)
<<
"new formatTransPass failed"
;
return
RET_ERROR
;
return
RET_ERROR
;
}
}
//
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass
->
SetQuantType
(
ctx
.
quantType
);
formatTransPass
->
SetFmk
(
ctx
.
fmk
);
formatTransPass
->
SetFmk
(
ctx
.
fmk
);
formatTransOptimizer
.
AddPass
(
formatTransPass
);
formatTransOptimizer
.
AddPass
(
formatTransPass
);
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
FormatTransFusionPass
());
formatTransOptimizer
.
AddPass
(
new
(
std
::
nothrow
)
FormatTransFusionPass
());
...
...
mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.cc
浏览文件 @
15d03657
...
@@ -191,7 +191,7 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
...
@@ -191,7 +191,7 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
return
InsertNode
(
graph
,
existNodeIter
,
place
,
inoutIdx
,
std
::
move
(
transNode
),
errorCode
);
return
InsertNode
(
graph
,
existNodeIter
,
place
,
inoutIdx
,
std
::
move
(
transNode
),
errorCode
);
}
}
//
void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void
FormatTransPass
::
SetQuantType
(
QuantType
quantType
)
{
this
->
quantType
=
quantType
;
}
void
FormatTransPass
::
SetFmk
(
converter
::
FmkType
fmkType
)
{
this
->
fmkType
=
fmkType
;
}
void
FormatTransPass
::
SetFmk
(
converter
::
FmkType
fmkType
)
{
this
->
fmkType
=
fmkType
;
}
...
...
mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.h
浏览文件 @
15d03657
...
@@ -33,7 +33,7 @@ class FormatTransPass : public GraphPass {
...
@@ -33,7 +33,7 @@ class FormatTransPass : public GraphPass {
STATUS
Run
(
schema
::
MetaGraphT
*
graph
)
override
;
STATUS
Run
(
schema
::
MetaGraphT
*
graph
)
override
;
//
void SetQuantType(QuantType quantType);
void
SetQuantType
(
QuantType
quantType
);
void
SetFmk
(
converter
::
FmkType
fmkType
);
void
SetFmk
(
converter
::
FmkType
fmkType
);
...
...
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
浏览文件 @
15d03657
...
@@ -43,7 +43,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
...
@@ -43,7 +43,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
return
0
;
return
0
;
}
}
//
void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void
WeightFormatPass
::
SetQuantType
(
QuantType
quantType
)
{
this
->
quantType
=
quantType
;
}
void
WeightFormatPass
::
SetFmkType
(
converter
::
FmkType
fmkType
)
{
this
->
fmkType
=
fmkType
;
}
void
WeightFormatPass
::
SetFmkType
(
converter
::
FmkType
fmkType
)
{
this
->
fmkType
=
fmkType
;
}
...
@@ -223,11 +223,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -223,11 +223,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto
weightIndex
=
node
->
inputIndex
.
at
(
1
);
auto
weightIndex
=
node
->
inputIndex
.
at
(
1
);
MS_ASSERT
(
subGraph
->
allTensors
.
size
()
>
weightIndex
);
MS_ASSERT
(
subGraph
->
allTensors
.
size
()
>
weightIndex
);
auto
&
weightTensor
=
subGraph
->
allTensors
[
weightIndex
];
auto
&
weightTensor
=
subGraph
->
allTensors
[
weightIndex
];
MS_ASSERT
(
weightTensor
->
dataType
==
-
22
);
// DataType_DT_FLOAT
MS_ASSERT
(
weightTensor
->
dataType
==
kNumberTypeInt8
);
// DataType_DT_FLOAT
STATUS
status
;
STATUS
status
;
if
(
opType
==
schema
::
PrimitiveType_Conv2D
)
{
// weight should be HWCK
if
(
opType
==
schema
::
PrimitiveType_Conv2D
)
{
// weight should be HWCK
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
// from caffe
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
// from caffe
if
(
weightTensor
->
dataType
==
-
22
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
MS_LOG
(
DEBUG
)
<<
"**weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
<<
weightTensor
->
format
MS_LOG
(
DEBUG
)
<<
"**weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
<<
weightTensor
->
format
<<
weightTensor
->
dataType
;
<<
weightTensor
->
dataType
;
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
...
@@ -237,7 +237,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -237,7 +237,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
}
}
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KHWC
)
{
// from onnx
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KHWC
)
{
// from onnx
if
(
weightTensor
->
dataType
==
-
22
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
}
else
{
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKHWC2HWCK
);
...
@@ -259,7 +259,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -259,7 +259,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
}
}
else
if
(
opType
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
// weight should be HWCK
}
else
if
(
opType
==
schema
::
PrimitiveType_DepthwiseConv2D
)
{
// weight should be HWCK
if
(
weightTensor
->
format
==
schema
::
Format_CKHW
)
{
// from caffe
if
(
weightTensor
->
format
==
schema
::
Format_CKHW
)
{
// from caffe
if
(
weightTensor
->
dataType
==
-
22
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
MS_LOG
(
DEBUG
)
<<
"**weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
,
weightTensor
->
format
,
MS_LOG
(
DEBUG
)
<<
"**weight tensor index: %d, format: %d, datatype: "
<<
weightIndex
,
weightTensor
->
format
,
weightTensor
->
dataType
;
weightTensor
->
dataType
;
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kCKHW2HWCK
);
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kCKHW2HWCK
);
...
@@ -272,11 +272,17 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
...
@@ -272,11 +272,17 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
else
if
(
weightTensor
->
format
==
schema
::
Format_HWCK
)
{
// from tf
}
else
if
(
weightTensor
->
format
==
schema
::
Format_HWCK
)
{
// from tf
return
0
;
return
0
;
}
else
if
(
weightTensor
->
format
==
schema
::
Format_CHWK
)
{
// from onnx
}
else
if
(
weightTensor
->
format
==
schema
::
Format_CHWK
)
{
// from onnx
if
(
weightTensor
->
dataType
==
-
22
)
{
// DataType_DT_UINT8) {
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
}
else
{
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kCHWK2HWCK
);
}
}
}
else
if
(
weightTensor
->
format
==
schema
::
Format_KCHW
)
{
if
(
weightTensor
->
dataType
==
kNumberTypeInt8
)
{
// DataType_DT_UINT8) {
status
=
TransFilterFormat
<
uint8_t
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
}
else
{
status
=
TransFilterFormat
<
float
>
(
weightTensor
.
get
(),
kKCHW2HWCK
);
}
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"Unsupported weightTensor format: "
<<
weightTensor
->
format
;
MS_LOG
(
ERROR
)
<<
"Unsupported weightTensor format: "
<<
weightTensor
->
format
;
return
-
1
;
return
-
1
;
...
...
mindspore/lite/tools/converter/optimizer/node/weight_format_pass.h
浏览文件 @
15d03657
...
@@ -29,7 +29,7 @@ class WeightFormatPass : public NodePass {
...
@@ -29,7 +29,7 @@ class WeightFormatPass : public NodePass {
~
WeightFormatPass
()
override
=
default
;
~
WeightFormatPass
()
override
=
default
;
//
void SetQuantType(QuantType quantType);
void
SetQuantType
(
QuantType
quantType
);
void
SetFmkType
(
converter
::
FmkType
fmkType
);
void
SetFmkType
(
converter
::
FmkType
fmkType
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录