Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
475543aa
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
475543aa
编写于
8月 18, 2020
作者:
C
chenjianping
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support infer datatype and format when shape infer fail
上级
9d55ac62
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
96 addition
and
54 deletion
+96
-54
mindspore/lite/src/ops/arithmetic.cc
mindspore/lite/src/ops/arithmetic.cc
+6
-2
mindspore/lite/src/ops/arithmetic_self.cc
mindspore/lite/src/ops/arithmetic_self.cc
+4
-2
mindspore/lite/src/ops/batch_to_space.cc
mindspore/lite/src/ops/batch_to_space.cc
+5
-2
mindspore/lite/src/ops/concat.cc
mindspore/lite/src/ops/concat.cc
+5
-3
mindspore/lite/src/ops/conv.cc
mindspore/lite/src/ops/conv.cc
+6
-2
mindspore/lite/src/ops/gather.cc
mindspore/lite/src/ops/gather.cc
+5
-3
mindspore/lite/src/ops/ops.cc
mindspore/lite/src/ops/ops.cc
+4
-1
mindspore/lite/src/ops/reduce.cc
mindspore/lite/src/ops/reduce.cc
+5
-2
mindspore/lite/src/ops/reshape.cc
mindspore/lite/src/ops/reshape.cc
+5
-3
mindspore/lite/src/ops/shape.cc
mindspore/lite/src/ops/shape.cc
+9
-12
mindspore/lite/src/ops/slice.cc
mindspore/lite/src/ops/slice.cc
+5
-3
mindspore/lite/src/ops/softmax.cc
mindspore/lite/src/ops/softmax.cc
+4
-1
mindspore/lite/src/ops/transpose.cc
mindspore/lite/src/ops/transpose.cc
+5
-3
mindspore/lite/src/ops/unsqueeze.cc
mindspore/lite/src/ops/unsqueeze.cc
+5
-2
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
...e/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
+1
-0
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
+16
-13
mindspore/lite/src/scheduler.cc
mindspore/lite/src/scheduler.cc
+6
-0
未找到文件。
mindspore/lite/src/ops/arithmetic.cc
浏览文件 @
475543aa
...
...
@@ -40,6 +40,11 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto
input_shape0
=
input0
->
shape
();
auto
input_shape1
=
input1
->
shape
();
auto
format
=
input0
->
GetFormat
();
output
->
SetFormat
(
format
);
output
->
set_data_type
(
input0
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
in_shape0_
.
resize
(
5
);
in_shape1_
.
resize
(
5
);
out_shape_
.
resize
(
5
);
...
...
@@ -94,9 +99,8 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
output_shape
.
push_back
(
out_shape_
[
i
]);
}
output
->
SetFormat
(
format
);
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input0
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
...
...
mindspore/lite/src/ops/arithmetic_self.cc
浏览文件 @
475543aa
...
...
@@ -26,10 +26,12 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/batch_to_space.cc
浏览文件 @
475543aa
...
...
@@ -39,6 +39,11 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
MS_LOG
(
ERROR
)
<<
"batch_to_space only support NHWC now!"
;
return
RET_FORMAT_ERR
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -86,9 +91,7 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape
[
kNHWC_w_index
]
=
input_shape
[
kNHWC_w_index
]
*
block_shape
->
Get
(
1
)
-
crops
->
Get
(
2
)
-
crops
->
Get
(
3
);
output_shape
[
kNHWC_c_index
]
=
input_shape
[
kNHWC_c_index
];
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/concat.cc
浏览文件 @
475543aa
...
...
@@ -34,6 +34,11 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG
(
ERROR
)
<<
"output size is error"
;
return
RET_PARAM_INVALID
;
}
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
concat_prim
=
this
->
primitive
->
value_as_Concat
();
MS_ASSERT
(
concat_prim
!=
nullptr
);
auto
input0_shape
=
inputs_
.
at
(
0
)
->
shape
();
...
...
@@ -74,9 +79,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
auto
output_shape
=
input0_shape
;
output_shape
[
axis
]
=
output_axis_dim
;
outputs_
[
0
]
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/conv.cc
浏览文件 @
475543aa
...
...
@@ -66,6 +66,11 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
input_tensor
!=
nullptr
);
MS_ASSERT
(
out_tensor
!=
nullptr
);
out_tensor
->
SetFormat
(
input_tensor
->
GetFormat
());
out_tensor
->
set_data_type
(
input_tensor
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input_tensor
->
shape
();
int
input_h
=
in_shape
.
at
(
1
);
int
input_w
=
in_shape
.
at
(
2
);
...
...
@@ -78,8 +83,7 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
out_shape
.
at
(
2
)
=
output_w
;
out_shape
.
at
(
3
)
=
weight_tensor
->
shape
()[
0
];
out_tensor
->
set_shape
(
out_shape
);
out_tensor
->
SetFormat
(
input_tensor
->
GetFormat
());
out_tensor
->
set_data_type
(
input_tensor
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
...
...
mindspore/lite/src/ops/gather.cc
浏览文件 @
475543aa
...
...
@@ -37,7 +37,11 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
input
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
gather_prim
=
this
->
primitive
->
value_as_Gather
();
MS_ASSERT
(
gather_prim
!=
nullptr
);
...
...
@@ -70,8 +74,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/ops.cc
浏览文件 @
475543aa
...
...
@@ -158,9 +158,12 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/reduce.cc
浏览文件 @
475543aa
...
...
@@ -33,6 +33,11 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if
(
input
==
nullptr
||
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
this
->
primitive
==
nullptr
)
{
return
RET_NULL_PTR
;
}
...
...
@@ -72,8 +77,6 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/reshape.cc
浏览文件 @
475543aa
...
...
@@ -82,6 +82,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
reshape_prim
=
this
->
primitive
->
value_as_Reshape
();
MS_ASSERT
(
reshape_prim
!=
nullptr
);
...
...
@@ -133,9 +138,6 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/shape.cc
浏览文件 @
475543aa
...
...
@@ -37,6 +37,15 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto
in_tensor
=
inputs_
.
front
();
auto
out_tensor
=
outputs_
.
front
();
auto
ret_dtype
=
out_tensor
->
set_data_type
(
kNumberTypeInt32
);
if
(
ret_dtype
!=
in_tensor
->
data_type
())
{
MS_LOG
(
ERROR
)
<<
"Set datatype fails."
;
return
RET_ERROR
;
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
out_shape
;
out_shape
.
push_back
(
static_cast
<
int
>
(
in_tensor
->
shape
().
size
()));
...
...
@@ -45,18 +54,6 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG
(
ERROR
)
<<
"Set shape fails."
;
return
RET_ERROR
;
}
auto
ret_dtype
=
out_tensor
->
set_data_type
(
in_tensor
->
data_type
());
if
(
ret_dtype
!=
in_tensor
->
data_type
())
{
MS_LOG
(
ERROR
)
<<
"Set datatype fails."
;
return
RET_ERROR
;
}
// todo
// auto ret_data = out_tensor->MallocData();
// if (ret_data != 0) {
// MS_LOG(ERROR) << "Allocate memory fails.";
// return RET_ERROR;
// }
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/slice.cc
浏览文件 @
475543aa
...
...
@@ -32,6 +32,11 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return
RET_PARAM_INVALID
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
auto
slice_prim
=
this
->
primitive
->
value_as_Slice
();
std
::
vector
<
int32_t
>
slice_begin
(
slice_prim
->
begin
()
->
begin
(),
slice_prim
->
begin
()
->
end
());
...
...
@@ -61,9 +66,6 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/softmax.cc
浏览文件 @
475543aa
...
...
@@ -26,9 +26,12 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/transpose.cc
浏览文件 @
475543aa
...
...
@@ -26,7 +26,11 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
MS_ASSERT
(
inputs_
.
size
()
==
kSingleNum
);
MS_ASSERT
(
outputs_
.
size
()
==
kSingleNum
);
auto
transpore_prim
=
this
->
primitive
->
value_as_Transpose
();
...
...
@@ -46,8 +50,6 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/unsqueeze.cc
浏览文件 @
475543aa
...
...
@@ -32,6 +32,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
if
(
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"output size is invalid"
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
unsqueeze_prim
=
this
->
primitive
->
value_as_Unsqueeze
();
auto
dims
=
unsqueeze_prim
->
axis
()
->
data
();
auto
in_shape
=
input
->
shape
();
...
...
@@ -65,9 +70,7 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
浏览文件 @
475543aa
...
...
@@ -58,6 +58,7 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
}
if
(
pack_input_
!=
nullptr
)
{
free
(
pack_input_
);
pack_input_
=
nullptr
;
}
}
bool
pre_trans_input_
=
false
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
浏览文件 @
475543aa
...
...
@@ -223,11 +223,11 @@ bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
kernel
::
LiteKernel
*
CpuConvFp32KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
op
P
arameter
,
const
Context
*
ctx
,
OpParameter
*
op
_p
arameter
,
const
Context
*
ctx
,
const
kernel
::
KernelKey
&
desc
,
const
lite
::
Primitive
*
primitive
)
{
MS_ASSERT
(
op
P
arameter
!=
nullptr
);
MS_ASSERT
(
op
_p
arameter
!=
nullptr
);
MS_ASSERT
(
desc
.
type
==
schema
::
PrimitiveType_Conv2D
);
auto
conv_param
=
reinterpret_cast
<
ConvParameter
*>
(
op
P
arameter
);
auto
conv_param
=
reinterpret_cast
<
ConvParameter
*>
(
op
_p
arameter
);
int
kernel_h
=
conv_param
->
kernel_h_
;
int
kernel_w
=
conv_param
->
kernel_w_
;
int
stride_h
=
conv_param
->
stride_h_
;
...
...
@@ -239,25 +239,28 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
conv_param
->
output_h_
=
outputs
.
front
()
->
Height
();
conv_param
->
output_w_
=
outputs
.
front
()
->
Width
();
bool
use_winograd
=
false
;
bool
use_sw
=
false
;
int
out_unit
;
InputTransformUnitFunc
input_trans_func
=
nullptr
;
OutputTransformUnitFunc
output_trans_func
=
nullptr
;
CheckIfUseWinograd
(
&
use_winograd
,
&
out_unit
,
conv_param
,
input_trans_func
,
output_trans_func
);
bool
use_sw
=
CheckIfUseSlideWindow
(
conv_param
);
if
(
primitive
!=
nullptr
&&
primitive
->
GetInferFlag
())
{
CheckIfUseWinograd
(
&
use_winograd
,
&
out_unit
,
conv_param
,
input_trans_func
,
output_trans_func
);
use_sw
=
CheckIfUseSlideWindow
(
conv_param
);
}
kernel
::
LiteKernel
*
kernel
;
if
(
kernel_h
==
1
&&
kernel_w
==
1
)
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution1x1CPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution1x1CPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
else
if
(
kernel_h
==
3
&&
kernel_w
==
3
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
)
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution3x3CPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
kernel
=
new
(
std
::
nothrow
)
kernel
::
Convolution3x3CPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
else
if
(
use_winograd
)
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionWinogradCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
,
out_unit
);
new
(
std
::
nothrow
)
kernel
::
ConvolutionWinogradCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
,
out_unit
);
}
else
if
(
use_sw
)
{
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op
P
arameter, inputs, outputs, ctx, primitive);
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op
_p
arameter, inputs, outputs, ctx, primitive);
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
else
{
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
kernel
=
new
(
std
::
nothrow
)
kernel
::
ConvolutionCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
}
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"kernel is nullptr."
;
...
...
@@ -266,8 +269,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
&&
ret
!=
RET_INFER_INVALID
)
{
delete
kernel
;
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
op
P
arameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
op
P
arameter
->
type_
));
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
op
_p
arameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
op
_p
arameter
->
type_
));
return
nullptr
;
}
return
kernel
;
...
...
mindspore/lite/src/scheduler.cc
浏览文件 @
475543aa
...
...
@@ -116,6 +116,12 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
}
}
else
{
primitive
->
SetInferFlag
(
false
);
auto
ret
=
primitive
->
InferShape
(
inputs
,
outputs
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"InferShape fail! name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
());
return
RET_INFER_ERR
;
}
}
}
return
RET_OK
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录