Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
989a7e9a
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
989a7e9a
编写于
8月 20, 2020
作者:
K
kai00
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
check memory fixed
上级
b3fac0cd
变更
11
隐藏空白更改
内联
并排
Showing
11 changed files
with
104 additions
and
61 deletions
+104
-61
mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc
...anf_importer/anf_populater/anf_node_populater_registry.cc
+8
-0
mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h
.../anf_importer/anf_populater/anf_node_populater_registry.h
+1
-1
mindspore/lite/tools/anf_importer/import_from_protobuf.cc
mindspore/lite/tools/anf_importer/import_from_protobuf.cc
+2
-0
mindspore/lite/tools/converter/converter.cc
mindspore/lite/tools/converter/converter.cc
+13
-17
mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
.../tools/converter/parser/caffe/caffe_convolution_parser.cc
+8
-5
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
...lite/tools/converter/parser/tflite/tflite_model_parser.cc
+1
-1
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
+8
-8
mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
+36
-24
mindspore/lite/tools/converter/quantizer/calc_quant_param.h
mindspore/lite/tools/converter/quantizer/calc_quant_param.h
+1
-1
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
...re/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+23
-2
mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
...pore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
+3
-2
未找到文件。
mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc
浏览文件 @
989a7e9a
...
...
@@ -18,6 +18,14 @@
#include <string>
namespace
mindspore
{
namespace
lite
{
AnfNodePopulaterRegistry
::~
AnfNodePopulaterRegistry
()
{
for
(
auto
ite
:
populaters
)
{
if
(
ite
.
second
!=
nullptr
)
{
delete
ite
.
second
;
ite
.
second
=
nullptr
;
}
}
}
AnfNodePopulaterRegistry
*
AnfNodePopulaterRegistry
::
GetInstance
()
{
static
AnfNodePopulaterRegistry
instance
;
return
&
instance
;
...
...
mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h
浏览文件 @
989a7e9a
...
...
@@ -23,7 +23,7 @@ namespace mindspore::lite {
class
AnfNodePopulaterRegistry
{
public:
AnfNodePopulaterRegistry
()
=
default
;
virtual
~
AnfNodePopulaterRegistry
()
=
default
;
virtual
~
AnfNodePopulaterRegistry
();
static
AnfNodePopulaterRegistry
*
GetInstance
();
AnfNodePopulater
*
GetNodePopulater
(
const
std
::
string
&
name
);
void
SetNodePopulater
(
const
std
::
string
&
name
,
AnfNodePopulater
*
populater
);
...
...
mindspore/lite/tools/anf_importer/import_from_protobuf.cc
浏览文件 @
989a7e9a
...
...
@@ -140,6 +140,8 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
auto
ret
=
memcpy_s
(
tensor_data_buf
,
tensor_info
->
Size
(),
initial_data
.
data
(),
initial_data
.
size
());
if
(
EOK
!=
ret
)
{
MS_LOG
(
ERROR
)
<<
"memcpy_s error"
;
delete
tensor_data_buf
;
delete
tensor_info
;
return
false
;
}
...
...
mindspore/lite/tools/converter/converter.cc
浏览文件 @
989a7e9a
...
...
@@ -43,18 +43,10 @@ Converter::Converter() {
}
Converter
::~
Converter
()
{
if
(
nullptr
!=
modelParser
)
{
delete
modelParser
;
}
if
(
nullptr
!=
modelImporter
)
{
delete
modelImporter
;
}
if
(
nullptr
!=
transform
)
{
delete
transform
;
}
if
(
nullptr
!=
anfTransform
)
{
delete
anfTransform
;
}
delete
modelParser
;
delete
modelImporter
;
delete
transform
;
delete
anfTransform
;
}
class
MindsporeImporter
:
public
Converter
{
...
...
@@ -154,7 +146,11 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
}
}
int
RunConverter
(
int
argc
,
const
char
**
argv
)
{
auto
flags
=
new
converter
::
Flags
;
std
::
unique_ptr
<
converter
::
Flags
>
flags
(
new
(
std
::
nothrow
)
converter
::
Flags
);
if
(
flags
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new flags error "
;
return
RET_ERROR
;
}
auto
status
=
flags
->
Init
(
argc
,
argv
);
if
(
status
==
RET_SUCCESS_EXIT
)
{
return
0
;
...
...
@@ -173,20 +169,20 @@ int RunConverter(int argc, const char **argv) {
auto
graph
=
std
::
make_shared
<
FuncGraph
>
();
auto
onnx_graph
=
AnfImporterFromProtobuf
::
ReadOnnxFromBinary
(
flags
->
modelFile
);
MindsporeImporter
mindsporeImporter
(
onnx_graph
,
graph
);
fb_graph
=
mindsporeImporter
.
Convert
(
flags
);
fb_graph
=
mindsporeImporter
.
Convert
(
flags
.
get
()
);
break
;
}
case
FmkType
::
FmkType_CAFFE
:
{
CaffeConverter
caffeConverter
;
fb_graph
=
caffeConverter
.
Convert
(
flags
);
fb_graph
=
caffeConverter
.
Convert
(
flags
.
get
()
);
}
break
;
case
FmkType
::
FmkType_TFLITE
:
{
TfliteConverter
tfLiteConverter
;
fb_graph
=
tfLiteConverter
.
Convert
(
flags
);
fb_graph
=
tfLiteConverter
.
Convert
(
flags
.
get
()
);
}
break
;
case
FmkType
::
FmkType_ONNX
:
{
OnnxConverter
onnxConverter
;
fb_graph
=
onnxConverter
.
Convert
(
flags
);
fb_graph
=
onnxConverter
.
Convert
(
flags
.
get
()
);
}
break
;
default:
{
MS_LOG
(
ERROR
)
<<
"Unsupported fmkType: "
<<
flags
->
fmk
;
...
...
mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
浏览文件 @
989a7e9a
...
...
@@ -26,7 +26,7 @@ void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::C
}
std
::
unique_ptr
<
schema
::
DepthwiseConv2DT
>
depthwiseConv2DParam
=
std
::
make_unique
<
schema
::
DepthwiseConv2DT
>
();
if
(
depthwiseConv2DParam
==
nullptr
)
{
// MS_LOGW("new DepthwiseConv2DT failed")
;
MS_LOG
(
ERROR
)
<<
"new DepthwiseConv2DT failed"
;
return
;
}
depthwiseConv2DParam
->
format
=
attr
->
format
;
...
...
@@ -53,8 +53,11 @@ void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::C
STATUS
CaffeConvolutionParser
::
Parse
(
const
caffe
::
LayerParameter
&
proto
,
const
caffe
::
LayerParameter
&
weight
,
schema
::
CNodeT
*
op
,
std
::
vector
<
schema
::
TensorT
*>
*
weightVec
)
{
op
->
name
=
proto
.
name
();
schema
::
Conv2DT
*
attr
=
new
schema
::
Conv2DT
();
std
::
unique_ptr
<
schema
::
Conv2DT
>
attr
(
new
(
std
::
nothrow
)
schema
::
Conv2DT
());
if
(
attr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new Conv2DT failed"
;
return
RET_ERROR
;
}
attr
->
format
=
schema
::
Format_NCHW
;
const
caffe
::
ConvolutionParameter
convParam
=
proto
.
convolution_param
();
...
...
@@ -118,9 +121,9 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c
attr
->
padMode
=
schema
::
PadMode_CAFFE
;
op
->
primitive
=
std
::
make_unique
<
schema
::
PrimitiveT
>
();
op
->
primitive
->
value
.
type
=
schema
::
PrimitiveType_Conv2D
;
op
->
primitive
->
value
.
value
=
attr
;
op
->
primitive
->
value
.
value
=
attr
.
get
()
;
ParseGroupConvolution
(
op
,
attr
);
ParseGroupConvolution
(
op
,
attr
.
release
()
);
status
=
convParser
.
ParseWeight
(
weight
,
weightVec
);
if
(
status
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ParseWeight for "
<<
proto
.
name
().
c_str
()
<<
" failed"
;
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
浏览文件 @
989a7e9a
...
...
@@ -159,7 +159,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
auto
isConst
=
(
!
tensor_buffer
->
data
.
empty
());
if
(
isConst
)
{
CopyConstTensorData
(
tflite_model_buffer
,
tflite_tensor
.
get
(),
tensor
.
get
());
}
else
if
(
tensor
->
dataType
==
TypeId
::
kNumberTypeUInt8
)
{
}
else
if
(
tensor
->
dataType
==
TypeId
::
kNumberTypeUInt8
)
{
// set in/out tensor to int8 to fit ms-lite op
tensor
->
dataType
=
TypeId
::
kNumberTypeInt8
;
}
...
...
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
浏览文件 @
989a7e9a
...
...
@@ -103,11 +103,13 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
const
float
stdValue
=
std
::
stof
(
stdValues
,
&
sz
);
sz
=
0
;
const
float
mean
=
std
::
stof
(
meanValues
,
&
sz
);
std
::
unique_ptr
<
InputArray
>
inArr
=
nullptr
;
if
(
inputInferType
==
"FLOAT"
)
{
mInputArray
=
new
InputArray
(
mean
,
stdValue
);
inArr
.
reset
(
new
(
std
::
nothrow
)
InputArray
(
mean
,
stdValue
)
);
}
else
{
mInputArray
=
new
InputArray
(
mean
,
stdValue
,
TypeId
::
kNumberTypeUInt8
);
inArr
.
reset
(
new
(
std
::
nothrow
)
InputArray
(
mean
,
stdValue
,
TypeId
::
kNumberTypeUInt8
)
);
}
mInputArray
=
inArr
.
get
();
mInputArray
->
InitQuantParam
();
}
...
...
@@ -527,9 +529,9 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
// quant bias data
auto
bShapeSize
=
GetShapeSize
(
*
(
biasTensor
.
get
()));
auto
*
qDatas
=
new
(
std
::
nothrow
)
int32_t
[
bShapeSize
]
;
std
::
unique_ptr
<
int32_t
[]
>
qDatas
(
new
(
std
::
nothrow
)
int32_t
[
bShapeSize
])
;
if
(
qDatas
==
nullptr
)
{
// MS_LOGE("new qDatas failed")
;
MS_LOG
(
ERROR
)
<<
"new qDatas failed"
;
return
RET_ERROR
;
}
void
*
biasData
=
biasTensor
->
data
.
data
();
...
...
@@ -541,13 +543,11 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
biasTensor
->
data
.
clear
();
biasTensor
->
data
.
resize
(
bShapeSize
*
sizeof
(
int32_t
));
auto
ret
=
memcpy_s
(
biasTensor
->
data
.
data
(),
bShapeSize
*
sizeof
(
int32_t
),
qDatas
,
bShapeSize
*
sizeof
(
int32_t
));
qDatas
.
get
()
,
bShapeSize
*
sizeof
(
int32_t
));
if
(
ret
!=
EOK
)
{
// MS_LOGE("memcpy_s failed: %d", ret);
delete
[]
qDatas
;
MS_LOG
(
ERROR
)
<<
"memcpy_s failed: "
<<
ret
;
return
RET_ERROR
;
}
delete
[]
qDatas
;
return
RET_OK
;
}
...
...
mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
浏览文件 @
989a7e9a
...
...
@@ -441,50 +441,62 @@ class CalcActivation : public QuantParamCalcer {
}
}
};
QuantParamCalcRegister
::~
QuantParamCalcRegister
()
{
for
(
auto
ite
:
_registerMap
)
{
if
(
ite
.
second
!=
nullptr
)
{
delete
ite
.
second
;
ite
.
second
=
nullptr
;
}
}
}
QuantParamCalcRegister
::
QuantParamCalcRegister
()
{
bool
hasError
=
false
;
auto
baseCalcer
=
new
(
std
::
nothrow
)
QuantParamCalcer
(
);
std
::
unique_ptr
<
QuantParamCalcer
>
baseCalcer
(
new
(
std
::
nothrow
)
QuantParamCalcer
()
);
if
(
baseCalcer
==
nullptr
)
{
// MS_LOGW("new QuantParamCalcer failed")
;
MS_LOG
(
ERROR
)
<<
"new QuantParamCalcer failed"
;
hasError
=
true
;
}
auto
commonCalcer
=
new
(
std
::
nothrow
)
CommonCalcer
(
);
std
::
unique_ptr
<
CommonCalcer
>
commonCalcer
(
new
(
std
::
nothrow
)
CommonCalcer
()
);
if
(
commonCalcer
==
nullptr
)
{
// MS_LOGW("new commonCalcer failed")
;
MS_LOG
(
ERROR
)
<<
"new commonCalcer failed"
;
hasError
=
true
;
}
auto
linearCalcer
=
new
(
std
::
nothrow
)
LinearCalcer
();
std
::
unique_ptr
<
LinearCalcer
>
linearCalcer
(
new
(
std
::
nothrow
)
LinearCalcer
());
if
(
linearCalcer
==
nullptr
)
{
// MS_LOGW("new linearCalcer failed")
;
MS_LOG
(
ERROR
)
<<
"new linearCalcer failed"
;
hasError
=
true
;
}
if
(
!
hasError
)
{
_registerMap
[
schema
::
PrimitiveType_Concat
]
=
new
CalcConcat
();
_registerMap
[
schema
::
PrimitiveType_Activation
]
=
new
CalcActivation
();
_registerMap
[
schema
::
PrimitiveType_Add
]
=
new
CalcAdd
();
_registerMap
[
schema
::
PrimitiveType_Mul
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_Conv2D
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_DepthwiseConv2D
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_Pooling
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Resize
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Reshape
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Shape
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Mul
]
=
commonCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_Conv2D
]
=
commonCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_DepthwiseConv2D
]
=
commonCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_Pooling
]
=
linearCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_Resize
]
=
linearCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_Reshape
]
=
linearCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_Shape
]
=
linearCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_SoftMax
]
=
new
CalcToSet
(
0
,
1
);
_registerMap
[
schema
::
PrimitiveType_Squeeze
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Squeeze
]
=
linearCalcer
.
get
()
;
_registerMap
[
schema
::
PrimitiveType_RealDiv
]
=
new
CalcRealDiv
();
_registerMap
[
schema
::
PrimitiveType_Reduce
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_BiasAdd
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_Mean
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Transpose
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_MatMul
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_FullConnection
]
=
commonCalcer
;
_registerMap
[
schema
::
PrimitiveType_Nchw2Nhwc
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Nhwc2Nchw
]
=
linearCalcer
;
_registerMap
[
schema
::
PrimitiveType_Reduce
]
=
commonCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_BiasAdd
]
=
commonCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_Mean
]
=
linearCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_Transpose
]
=
linearCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_MatMul
]
=
commonCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_FullConnection
]
=
commonCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_Nchw2Nhwc
]
=
linearCalcer
.
get
();
_registerMap
[
schema
::
PrimitiveType_Nhwc2Nchw
]
=
linearCalcer
.
get
();
// todo
// detection_postprocess op's quant param will not infer only fetch from preNode or postNode
// because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
// if quantTransNode is inserted after detection_postprocess node, there will be some errors
_registerMap
[
schema
::
PrimitiveType_DetectionPostProcess
]
=
baseCalcer
;
_registerMap
[
schema
::
PrimitiveType_DetectionPostProcess
]
=
baseCalcer
.
get
();
baseCalcer
.
release
();
linearCalcer
.
release
();
commonCalcer
.
release
();
}
}
...
...
mindspore/lite/tools/converter/quantizer/calc_quant_param.h
浏览文件 @
989a7e9a
...
...
@@ -55,7 +55,7 @@ class LinearCalcer : public QuantParamCalcer {
class
QuantParamCalcRegister
{
public:
virtual
~
QuantParamCalcRegister
()
=
default
;
virtual
~
QuantParamCalcRegister
();
QuantParamCalcer
*
GetQuantParamCalcer
(
schema
::
PrimitiveType
opType
);
static
QuantParamCalcRegister
*
GetInstance
();
...
...
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
浏览文件 @
989a7e9a
...
...
@@ -55,14 +55,18 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
if
(
lite_tensor_size
==
0
)
{
return
input_tensors
;
}
auto
tensor_data
=
new
(
std
::
nothrow
)
char
[
lite_tensor_size
/
sizeof
(
char
)];
auto
tensor_data
=
new
(
std
::
nothrow
)
char
[
lite_tensor_size
/
sizeof
(
char
)];
if
(
tensor_data
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"tensor_data is nullptr"
;
delete
lite_tensor
;
return
input_tensors
;
}
auto
ret
=
memcpy_s
(
tensor_data
,
lite_tensor_size
,
tensorT
->
data
.
data
(),
lite_tensor_size
);
if
(
ret
!=
EOK
)
{
delete
lite_tensor
;
delete
tensor_data
;
MS_LOG
(
EXCEPTION
)
<<
"memcpy error: "
<<
ret
;
return
input_tensors
;
}
lite_tensor
->
SetData
(
tensor_data
);
input_tensors
.
emplace_back
(
lite_tensor
);
...
...
@@ -111,7 +115,9 @@ const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *ten
}
auto
ret
=
memcpy_s
(
tensor_data
,
size
*
sizeof
(
float
),
tensor
->
Data
(),
size
*
sizeof
(
float
));
if
(
ret
!=
EOK
)
{
delete
tensor_data
;
MS_LOG
(
EXCEPTION
)
<<
"memcpy error: "
<<
ret
;
return
parameter
;
}
param_value
->
set_tensor_addr
(
tensor_data
);
param_value
->
set_tensor_size
(
size
*
sizeof
(
float
)
/
sizeof
(
uint8_t
));
...
...
@@ -138,7 +144,17 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
return
nullptr
;
}
}
// namespace
void
FreeInputTensor
(
std
::
vector
<
Tensor
*>
*
input_tensor
)
{
MS_ASSERT
(
input_tensor
!=
nullptr
);
for
(
size_t
i
=
0
;
i
<
input_tensor
->
size
();
i
++
)
{
if
((
*
input_tensor
)[
i
]
==
nullptr
)
{
continue
;
}
delete
(
*
input_tensor
)[
i
];
(
*
input_tensor
)[
i
]
=
nullptr
;
}
return
;
}
const
AnfNodePtr
ConstFoldPass
::
Process
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
EquivPtr
&
)
const
{
CheckIfFuncGraphIsNull
(
func_graph
);
...
...
@@ -154,6 +170,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
auto
input_cnode
=
input_node
->
cast
<
CNodePtr
>
();
auto
input_tensors
=
GetCNodeInputTensors
(
input_cnode
);
if
(
input_tensors
.
empty
()
||
input_tensors
.
size
()
!=
input_cnode
->
inputs
().
size
()
-
1
)
{
FreeInputTensor
(
&
input_tensors
);
return
any_node
;
}
MS_LOG
(
INFO
)
<<
"Begin fold node:"
<<
input_node
->
fullname_with_scope
();
...
...
@@ -163,21 +180,25 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
auto
lite_primitive
=
mindspore
::
lite
::
PrimitiveC
::
CreatePrimitive
(
scheam_primitive
);
if
(
lite_primitive
==
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"constant_folding schedule node lite primitive nullptr"
;
FreeInputTensor
(
&
input_tensors
);
return
nullptr
;
}
lite_primitive
->
InferShape
(
input_tensors
,
output_tensors
);
auto
lite_kernel
=
GetLiteKernel
(
input_tensors
,
output_tensors
,
lite_primitive
);
if
(
lite_kernel
==
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"constant_folding schedule node lite kernel nullptr"
;
FreeInputTensor
(
&
input_tensors
);
return
nullptr
;
}
auto
ret
=
lite_kernel
->
Run
();
if
(
0
!=
ret
)
{
FreeInputTensor
(
&
input_tensors
);
MS_LOG
(
EXCEPTION
)
<<
"run kernel failed, name: "
<<
lite_kernel
->
name
();
}
auto
new_parameter
=
CreateNewParamter
(
func_graph
,
output_tensors
.
front
());
new_parameter
->
set_name
(
input_node
->
fullname_with_scope
());
any_node
->
set_input
(
i
,
new_parameter
);
FreeInputTensor
(
&
input_tensors
);
}
}
return
any_node
;
...
...
mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
浏览文件 @
989a7e9a
...
...
@@ -73,7 +73,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
auto
abstr
=
transform_node
->
abstract
();
int
kernel_nums
=
Get_Kenrnel_nums
(
conv_node
);
if
(
kernel_nums
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Unsupported conv node, "
<<
conv_node
->
DebugString
();
MS_LOG
(
INFO
)
<<
"Unsupported conv node, "
<<
conv_node
->
DebugString
();
return
node
;
}
auto
trans_scale
=
new
(
std
::
nothrow
)
float
[
kernel_nums
];
...
...
@@ -84,6 +84,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
auto
trans_bias
=
new
(
std
::
nothrow
)
float
[
kernel_nums
];
if
(
trans_bias
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"tensor_data is nullptr"
;
delete
trans_scale
;
return
nullptr
;
}
GenTransParam
(
transform_node
,
kernel_nums
,
trans_scale
,
trans_bias
);
...
...
@@ -164,7 +165,7 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
bias_flag
=
true
;
}
else
{
bias_data
=
new
(
std
::
nothrow
)
float
[
kernel_num
];
if
(
trans_scale
==
nullptr
)
{
if
(
bias_data
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"tensor_data is nullptr"
;
return
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录