Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c22b56e8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c22b56e8
编写于
8月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4645 [MS][LITE][Develop]support infer data type and format when infer shape fail
Merge pull request !4645 from chenjianping/lite_dev3
上级
28755b2f
475543aa
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
96 addition
and
54 deletion
+96
-54
mindspore/lite/src/ops/arithmetic.cc
mindspore/lite/src/ops/arithmetic.cc
+6
-2
mindspore/lite/src/ops/arithmetic_self.cc
mindspore/lite/src/ops/arithmetic_self.cc
+4
-2
mindspore/lite/src/ops/batch_to_space.cc
mindspore/lite/src/ops/batch_to_space.cc
+5
-2
mindspore/lite/src/ops/concat.cc
mindspore/lite/src/ops/concat.cc
+5
-3
mindspore/lite/src/ops/conv.cc
mindspore/lite/src/ops/conv.cc
+6
-2
mindspore/lite/src/ops/gather.cc
mindspore/lite/src/ops/gather.cc
+5
-3
mindspore/lite/src/ops/ops.cc
mindspore/lite/src/ops/ops.cc
+4
-1
mindspore/lite/src/ops/reduce.cc
mindspore/lite/src/ops/reduce.cc
+5
-2
mindspore/lite/src/ops/reshape.cc
mindspore/lite/src/ops/reshape.cc
+5
-3
mindspore/lite/src/ops/shape.cc
mindspore/lite/src/ops/shape.cc
+9
-12
mindspore/lite/src/ops/slice.cc
mindspore/lite/src/ops/slice.cc
+5
-3
mindspore/lite/src/ops/softmax.cc
mindspore/lite/src/ops/softmax.cc
+4
-1
mindspore/lite/src/ops/transpose.cc
mindspore/lite/src/ops/transpose.cc
+5
-3
mindspore/lite/src/ops/unsqueeze.cc
mindspore/lite/src/ops/unsqueeze.cc
+5
-2
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
...e/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
+1
-0
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
+16
-13
mindspore/lite/src/scheduler.cc
mindspore/lite/src/scheduler.cc
+6
-0
未找到文件。
mindspore/lite/src/ops/arithmetic.cc
浏览文件 @
c22b56e8
...
...
@@ -40,6 +40,11 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto
input_shape0
=
input0
->
shape
();
auto
input_shape1
=
input1
->
shape
();
auto
format
=
input0
->
GetFormat
();
output
->
SetFormat
(
format
);
output
->
set_data_type
(
input0
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
in_shape0_
.
resize
(
5
);
in_shape1_
.
resize
(
5
);
out_shape_
.
resize
(
5
);
...
...
@@ -94,9 +99,8 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
output_shape
.
push_back
(
out_shape_
[
i
]);
}
output
->
SetFormat
(
format
);
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input0
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
...
...
mindspore/lite/src/ops/arithmetic_self.cc
浏览文件 @
c22b56e8
...
...
@@ -26,10 +26,12 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/batch_to_space.cc
浏览文件 @
c22b56e8
...
...
@@ -39,6 +39,11 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
MS_LOG
(
ERROR
)
<<
"batch_to_space only support NHWC now!"
;
return
RET_FORMAT_ERR
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -86,9 +91,7 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape
[
kNHWC_w_index
]
=
input_shape
[
kNHWC_w_index
]
*
block_shape
->
Get
(
1
)
-
crops
->
Get
(
2
)
-
crops
->
Get
(
3
);
output_shape
[
kNHWC_c_index
]
=
input_shape
[
kNHWC_c_index
];
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/concat.cc
浏览文件 @
c22b56e8
...
...
@@ -34,6 +34,11 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG
(
ERROR
)
<<
"output size is error"
;
return
RET_PARAM_INVALID
;
}
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
concat_prim
=
this
->
primitive
->
value_as_Concat
();
MS_ASSERT
(
concat_prim
!=
nullptr
);
auto
input0_shape
=
inputs_
.
at
(
0
)
->
shape
();
...
...
@@ -74,9 +79,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
auto
output_shape
=
input0_shape
;
output_shape
[
axis
]
=
output_axis_dim
;
outputs_
[
0
]
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/conv.cc
浏览文件 @
c22b56e8
...
...
@@ -66,6 +66,11 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
input_tensor
!=
nullptr
);
MS_ASSERT
(
out_tensor
!=
nullptr
);
out_tensor
->
SetFormat
(
input_tensor
->
GetFormat
());
out_tensor
->
set_data_type
(
input_tensor
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input_tensor
->
shape
();
int
input_h
=
in_shape
.
at
(
1
);
int
input_w
=
in_shape
.
at
(
2
);
...
...
@@ -78,8 +83,7 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
out_shape
.
at
(
2
)
=
output_w
;
out_shape
.
at
(
3
)
=
weight_tensor
->
shape
()[
0
];
out_tensor
->
set_shape
(
out_shape
);
out_tensor
->
SetFormat
(
input_tensor
->
GetFormat
());
out_tensor
->
set_data_type
(
input_tensor
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
...
...
mindspore/lite/src/ops/gather.cc
浏览文件 @
c22b56e8
...
...
@@ -37,7 +37,11 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
input
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
gather_prim
=
this
->
primitive
->
value_as_Gather
();
MS_ASSERT
(
gather_prim
!=
nullptr
);
...
...
@@ -70,8 +74,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/ops.cc
浏览文件 @
c22b56e8
...
...
@@ -158,9 +158,12 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/reduce.cc
浏览文件 @
c22b56e8
...
...
@@ -33,6 +33,11 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if
(
input
==
nullptr
||
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
this
->
primitive
==
nullptr
)
{
return
RET_NULL_PTR
;
}
...
...
@@ -72,8 +77,6 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/reshape.cc
浏览文件 @
c22b56e8
...
...
@@ -82,6 +82,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
reshape_prim
=
this
->
primitive
->
value_as_Reshape
();
MS_ASSERT
(
reshape_prim
!=
nullptr
);
...
...
@@ -133,9 +138,6 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/shape.cc
浏览文件 @
c22b56e8
...
...
@@ -37,6 +37,15 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto
in_tensor
=
inputs_
.
front
();
auto
out_tensor
=
outputs_
.
front
();
auto
ret_dtype
=
out_tensor
->
set_data_type
(
kNumberTypeInt32
);
if
(
ret_dtype
!=
in_tensor
->
data_type
())
{
MS_LOG
(
ERROR
)
<<
"Set datatype fails."
;
return
RET_ERROR
;
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
out_shape
;
out_shape
.
push_back
(
static_cast
<
int
>
(
in_tensor
->
shape
().
size
()));
...
...
@@ -45,18 +54,6 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG
(
ERROR
)
<<
"Set shape fails."
;
return
RET_ERROR
;
}
auto
ret_dtype
=
out_tensor
->
set_data_type
(
in_tensor
->
data_type
());
if
(
ret_dtype
!=
in_tensor
->
data_type
())
{
MS_LOG
(
ERROR
)
<<
"Set datatype fails."
;
return
RET_ERROR
;
}
// todo
// auto ret_data = out_tensor->MallocData();
// if (ret_data != 0) {
// MS_LOG(ERROR) << "Allocate memory fails.";
// return RET_ERROR;
// }
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/slice.cc
浏览文件 @
c22b56e8
...
...
@@ -32,6 +32,11 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return
RET_PARAM_INVALID
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
auto
slice_prim
=
this
->
primitive
->
value_as_Slice
();
std
::
vector
<
int32_t
>
slice_begin
(
slice_prim
->
begin
()
->
begin
(),
slice_prim
->
begin
()
->
end
());
...
...
@@ -61,9 +66,6 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/softmax.cc
浏览文件 @
c22b56e8
...
...
@@ -26,9 +26,12 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/transpose.cc
浏览文件 @
c22b56e8
...
...
@@ -26,7 +26,11 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
MS_ASSERT
(
inputs_
.
size
()
==
kSingleNum
);
MS_ASSERT
(
outputs_
.
size
()
==
kSingleNum
);
auto
transpore_prim
=
this
->
primitive
->
value_as_Transpose
();
...
...
@@ -46,8 +50,6 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/unsqueeze.cc
浏览文件 @
c22b56e8
...
...
@@ -32,6 +32,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
if
(
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"output size is invalid"
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
unsqueeze_prim
=
this
->
primitive
->
value_as_Unsqueeze
();
auto
dims
=
unsqueeze_prim
->
axis
()
->
data
();
auto
in_shape
=
input
->
shape
();
...
...
@@ -65,9 +70,7 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
浏览文件 @
c22b56e8
...
...
@@ -58,6 +58,7 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
}
if
(
pack_input_
!=
nullptr
)
{
free
(
pack_input_
);
pack_input_
=
nullptr
;
}
}
bool
pre_trans_input_
=
false
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
浏览文件 @
c22b56e8
...
...
@@ -223,11 +223,11 @@ bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
kernel
::
LiteKernel
*
CpuConvFp32KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
op
P
arameter
,
const
Context
*
ctx
,
OpParameter
*
op
_p
arameter
,
const
Context
*
ctx
,
const
kernel
::
KernelKey
&
desc
,
const
lite
::
Primitive
*
primitive
)
{
MS_ASSERT
(
op
P
arameter
!=
nullptr
);
MS_ASSERT
(
op
_p
arameter
!=
nullptr
);
MS_ASSERT
(
desc
.
type
==
schema
::
PrimitiveType_Conv2D
);
auto
conv_param
=
reinterpret_cast
<
ConvParameter
*>
(
op
P
arameter
);
auto
conv_param
=
reinterpret_cast
<
ConvParameter
*>
(
op
_p
arameter
);
int
kernel_h
=
conv_param
->
kernel_h_
;
int
kernel_w
=
conv_param
->
kernel_w_
;
int
stride_h
=
conv_param
->
stride_h_
;
...
...
@@ -239,25 +239,28 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
conv_param
->
output_h_
=
outputs
.
front
()
->
Height
();
conv_param
->
output_w_
=
outputs
.
front
()
->
Width
();
bool
use_winograd
=
false
;
bool
use_sw
=
false
;
int
out_unit
;
InputTransformUnitFunc
input_trans_func
=
nullptr
;
OutputTransformUnitFunc
output_trans_func
=
nullptr
;
CheckIfUseWinograd
(
&
use_winograd
,
&
out_unit
,
conv_param
,
input_trans_func
,
output_trans_func
);
bool
use_sw
=
CheckIfUseSlideWindow
(
conv_param
);
if
(
primitive
!=
nullptr
&&
primitive
->
GetInferFlag
())
{
CheckIfUseWinograd
(
&
use_winograd
,
&
out_unit
,
conv_param
,
input_trans_func
,
output_trans_func
);
use_sw
=
CheckIfUseSlideWindow
(
conv_param
);
}
kernel
::
LiteKernel
*
kernel
;
if
(
kernel_h
==
1
&&
kernel_w
==
1
)
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution1x1CPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution1x1CPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
else
if
(
kernel_h
==
3
&&
kernel_w
==
3
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
)
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution3x3CPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution3x3CPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
else
if
(
use_winograd
)
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionWinogradCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
,
out_unit
);
new
(
std
::
nothrow
)
kernel
::
ConvolutionWinogradCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
,
out_unit
);
}
else
if
(
use_sw
)
{
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op
P
arameter, inputs, outputs, ctx, primitive);
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op
_p
arameter, inputs, outputs, ctx, primitive);
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
else
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"kernel is nullptr."
;
...
...
@@ -266,8 +269,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
&&
ret
!=
RET_INFER_INVALID
)
{
delete
kernel
;
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
op
P
arameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
op
P
arameter
->
type_
));
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
op
_p
arameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
op
_p
arameter
->
type_
));
return
nullptr
;
}
return
kernel
;
...
...
mindspore/lite/src/scheduler.cc
浏览文件 @
c22b56e8
...
...
@@ -116,6 +116,12 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
}
}
else
{
primitive
->
SetInferFlag
(
false
);
auto
ret
=
primitive
->
InferShape
(
inputs
,
outputs
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"InferShape fail! name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
());
return
RET_INFER_ERR
;
}
}
}
return
RET_OK
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录