Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
030af09f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
030af09f
编写于
8月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4665 [MS][LITE][Develop]support infer data type and format when infer shape fail
Merge pull request !4665 from chenjianping/lite_dev3
上级
63e0ad24
2f36b91a
变更
47
隐藏空白更改
内联
并排
Showing
47 changed file
with
298 addition
and
126 deletion
+298
-126
mindspore/lite/src/ops/addn.cc
mindspore/lite/src/ops/addn.cc
+6
-2
mindspore/lite/src/ops/argmax.cc
mindspore/lite/src/ops/argmax.cc
+7
-2
mindspore/lite/src/ops/argmin.cc
mindspore/lite/src/ops/argmin.cc
+6
-2
mindspore/lite/src/ops/broadcast_to.cc
mindspore/lite/src/ops/broadcast_to.cc
+6
-3
mindspore/lite/src/ops/cast.cc
mindspore/lite/src/ops/cast.cc
+7
-6
mindspore/lite/src/ops/constant_of_shape.cc
mindspore/lite/src/ops/constant_of_shape.cc
+6
-3
mindspore/lite/src/ops/crop.cc
mindspore/lite/src/ops/crop.cc
+4
-1
mindspore/lite/src/ops/deconv2d.cc
mindspore/lite/src/ops/deconv2d.cc
+5
-3
mindspore/lite/src/ops/dedepthwise_conv2d.cc
mindspore/lite/src/ops/dedepthwise_conv2d.cc
+5
-3
mindspore/lite/src/ops/depth_to_space.cc
mindspore/lite/src/ops/depth_to_space.cc
+6
-4
mindspore/lite/src/ops/depthwise_conv2d.cc
mindspore/lite/src/ops/depthwise_conv2d.cc
+5
-3
mindspore/lite/src/ops/embedding_lookup.cc
mindspore/lite/src/ops/embedding_lookup.cc
+6
-1
mindspore/lite/src/ops/expand_dims.cc
mindspore/lite/src/ops/expand_dims.cc
+5
-2
mindspore/lite/src/ops/fill.cc
mindspore/lite/src/ops/fill.cc
+5
-2
mindspore/lite/src/ops/flatten.cc
mindspore/lite/src/ops/flatten.cc
+7
-2
mindspore/lite/src/ops/full_connection.cc
mindspore/lite/src/ops/full_connection.cc
+5
-3
mindspore/lite/src/ops/gather_nd.cc
mindspore/lite/src/ops/gather_nd.cc
+6
-2
mindspore/lite/src/ops/lstm.cc
mindspore/lite/src/ops/lstm.cc
+9
-4
mindspore/lite/src/ops/matmul.cc
mindspore/lite/src/ops/matmul.cc
+7
-2
mindspore/lite/src/ops/mean.cc
mindspore/lite/src/ops/mean.cc
+5
-2
mindspore/lite/src/ops/nchw2nhwc.cc
mindspore/lite/src/ops/nchw2nhwc.cc
+5
-2
mindspore/lite/src/ops/nhwc2nchw.cc
mindspore/lite/src/ops/nhwc2nchw.cc
+5
-2
mindspore/lite/src/ops/one_hot.cc
mindspore/lite/src/ops/one_hot.cc
+13
-10
mindspore/lite/src/ops/pad.cc
mindspore/lite/src/ops/pad.cc
+10
-6
mindspore/lite/src/ops/pooling.cc
mindspore/lite/src/ops/pooling.cc
+5
-3
mindspore/lite/src/ops/power.cc
mindspore/lite/src/ops/power.cc
+6
-2
mindspore/lite/src/ops/prior_box.cc
mindspore/lite/src/ops/prior_box.cc
+9
-6
mindspore/lite/src/ops/quant_dtype_cast.cc
mindspore/lite/src/ops/quant_dtype_cast.cc
+4
-1
mindspore/lite/src/ops/range.cc
mindspore/lite/src/ops/range.cc
+8
-2
mindspore/lite/src/ops/rank.cc
mindspore/lite/src/ops/rank.cc
+5
-2
mindspore/lite/src/ops/resize.cc
mindspore/lite/src/ops/resize.cc
+6
-3
mindspore/lite/src/ops/reverse_sequence.cc
mindspore/lite/src/ops/reverse_sequence.cc
+5
-1
mindspore/lite/src/ops/roi_pooling.cc
mindspore/lite/src/ops/roi_pooling.cc
+5
-2
mindspore/lite/src/ops/scatter_nd.cc
mindspore/lite/src/ops/scatter_nd.cc
+5
-2
mindspore/lite/src/ops/space_to_batch.cc
mindspore/lite/src/ops/space_to_batch.cc
+6
-2
mindspore/lite/src/ops/space_to_depth.cc
mindspore/lite/src/ops/space_to_depth.cc
+6
-2
mindspore/lite/src/ops/split.cc
mindspore/lite/src/ops/split.cc
+7
-0
mindspore/lite/src/ops/squeeze.cc
mindspore/lite/src/ops/squeeze.cc
+5
-2
mindspore/lite/src/ops/stack.cc
mindspore/lite/src/ops/stack.cc
+5
-2
mindspore/lite/src/ops/strided_slice.cc
mindspore/lite/src/ops/strided_slice.cc
+5
-2
mindspore/lite/src/ops/tile.cc
mindspore/lite/src/ops/tile.cc
+6
-2
mindspore/lite/src/ops/topk.cc
mindspore/lite/src/ops/topk.cc
+7
-4
mindspore/lite/src/ops/unique.cc
mindspore/lite/src/ops/unique.cc
+5
-2
mindspore/lite/src/ops/unstack.cc
mindspore/lite/src/ops/unstack.cc
+8
-2
mindspore/lite/src/ops/where.cc
mindspore/lite/src/ops/where.cc
+5
-2
mindspore/lite/src/ops/zeros_like.cc
mindspore/lite/src/ops/zeros_like.cc
+4
-2
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c
...pore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c
+20
-6
未找到文件。
mindspore/lite/src/ops/addn.cc
浏览文件 @
030af09f
...
...
@@ -43,6 +43,11 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG
(
ERROR
)
<<
"input size"
<<
inputs
.
size
()
<<
" is error!"
;
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
for
(
int
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
if
(
inputs
.
at
(
i
)
->
shape
()
!=
inputs
.
at
(
0
)
->
shape
())
{
MS_LOG
(
ERROR
)
<<
"AddN inputs shape is not equal!"
;
...
...
@@ -53,9 +58,8 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return
RET_INPUT_TENSOR_ERROR
;
}
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/argmax.cc
浏览文件 @
030af09f
...
...
@@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if
(
inputs_
.
size
()
!=
kSingleNum
||
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"tensor number is error."
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
argmax_prim
=
this
->
primitive
->
value_as_ArgMax
();
std
::
vector
<
int
>
output_shape
(
input
->
shape
());
auto
input_shape_size
=
input
->
shape
().
size
();
...
...
@@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
else
{
output_shape
[
axis
]
=
argmax_prim
->
topK
();
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/argmin.cc
浏览文件 @
030af09f
...
...
@@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if
(
inputs_
.
size
()
!=
kSingleNum
||
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"tensor number is error."
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
argmin_prim
=
this
->
primitive
->
value_as_ArgMin
();
auto
input_shape_size
=
input
->
shape
().
size
();
int
axis
=
argmin_prim
->
axis
()
<
0
?
argmin_prim
->
axis
()
+
input_shape_size
:
argmin_prim
->
axis
();
...
...
@@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
else
{
output_shape
[
axis
]
=
argmin_prim
->
topK
();
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/broadcast_to.cc
浏览文件 @
030af09f
...
...
@@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
return
1
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int32_t
>
dst_shape
(
this
->
primitive
->
value_as_BroadcastTo
()
->
dst_shape
()
->
begin
(),
this
->
primitive
->
value_as_BroadcastTo
()
->
dst_shape
()
->
end
());
auto
input_shape
=
input
->
shape
();
...
...
@@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
shape
[
i
]
=
dst_shape
[
i
];
--
input_shape_index
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_shape
(
shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/cast.cc
浏览文件 @
030af09f
...
...
@@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"tensor number is error."
;
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
auto
cast_prim
=
this
->
primitive
->
value_as_Cast
();
MS_ASSERT
(
cast_prim
!=
nullptr
);
output
->
set_data_type
(
static_cast
<
TypeId
>
(
cast_prim
->
dstT
()));
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
input
->
data_type
()
!=
cast_prim
->
srcT
())
{
MS_LOG
(
ERROR
)
<<
"input dataType is error"
;
return
RET_INPUT_TENSOR_ERROR
;
...
...
@@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"Unsupported input data type "
<<
input
->
data_type
();
return
RET_INPUT_TENSOR_ERROR
;
}
if
(
cast_prim
->
dstT
()
!=
kNumberTypeFloat
&&
cast_prim
->
dstT
()
!=
kNumberTypeFloat32
)
{
MS_LOG
(
ERROR
)
<<
"Invalid output datatype "
<<
cast_prim
->
dstT
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
TypeId
::
kNumberTypeFloat32
);
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/constant_of_shape.cc
浏览文件 @
030af09f
...
...
@@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
return
RET_ERROR
;
}
auto
in_tensor
=
inputs_
.
front
();
auto
in_data
=
reinterpret_cast
<
int
*>
(
in_tensor
->
Data
());
auto
out_tensor
=
outputs_
.
front
();
out_tensor
->
set_data_type
(
kNumberTypeFloat32
);
out_tensor
->
SetFormat
(
in_tensor
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_data
=
reinterpret_cast
<
int
*>
(
in_tensor
->
Data
());
int
size
=
in_tensor
->
ElementsNum
();
std
::
vector
<
int
>
out_shape
(
size
);
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
out_shape
[
i
]
=
in_data
[
i
];
}
out_tensor
->
set_shape
(
out_shape
);
out_tensor
->
set_data_type
(
kNumberTypeFloat32
);
out_tensor
->
SetFormat
(
in_tensor
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/crop.cc
浏览文件 @
030af09f
...
...
@@ -46,9 +46,12 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG
(
ERROR
)
<<
"Invalid output/input size! output size: "
<<
outputs
.
size
()
<<
",input size: "
<<
inputs
.
size
();
return
RET_PARAM_INVALID
;
}
outputs
[
0
]
->
set_shape
(
inputs
[
1
]
->
shape
());
outputs
[
0
]
->
SetFormat
(
inputs
[
0
]
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
inputs
[
0
]
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
outputs
[
0
]
->
set_shape
(
inputs
[
1
]
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/deconv2d.cc
浏览文件 @
030af09f
...
...
@@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
weight
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int32_t
input_h
=
input
->
Height
();
int32_t
input_w
=
input
->
Width
();
...
...
@@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
std
::
vector
<
int
>
out_shape
=
{
output_n
,
output_h
,
output_w
,
output_c
};
output
->
set_shape
(
out_shape
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/dedepthwise_conv2d.cc
浏览文件 @
030af09f
...
...
@@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT
(
weight
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input
->
shape
();
int
input_h
=
in_shape
.
at
(
1
);
int
input_w
=
in_shape
.
at
(
2
);
...
...
@@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape
.
at
(
3
)
=
weight
->
shape
()[
0
]
*
weight
->
shape
()[
3
];
// in_channel * out_channel
output
->
set_shape
(
out_shape
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/depth_to_space.cc
浏览文件 @
030af09f
...
...
@@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG
(
ERROR
)
<<
"depth_to_space only support NHWC now!"
;
return
1
;
}
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
[
NHWC_W
]
=
input_shape
[
NHWC_W
]
*
block_size
;
output_shape
[
NHWC_C
]
=
input_shape
[
NHWC_C
]
/
(
block_size
*
block_size
);
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/depthwise_conv2d.cc
浏览文件 @
030af09f
...
...
@@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT
(
weight
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input
->
shape
();
int
input_h
=
in_shape
.
at
(
1
);
int
input_w
=
in_shape
.
at
(
2
);
...
...
@@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape
.
at
(
3
)
=
weight
->
shape
()[
0
]
*
weight
->
shape
()[
3
];
// in_channel * out_channel
output
->
set_shape
(
out_shape
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/embedding_lookup.cc
浏览文件 @
030af09f
...
...
@@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
MS_ASSERT
(
ids
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
params_
->
GetFormat
());
output
->
set_data_type
(
params_
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
embedding_shape
=
params_
->
shape
();
embedding_shape
.
erase
(
embedding_shape
.
begin
());
std
::
vector
<
int
>
output_shape
(
ids
->
shape
());
...
...
@@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
}
}
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
params_
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/expand_dims.cc
浏览文件 @
030af09f
...
...
@@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if
(
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"output size is invalid"
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
expand_dims_prim
=
this
->
primitive
->
value_as_ExpandDims
();
int
dim
=
expand_dims_prim
->
dim
();
if
(
dim
<
0
)
{
...
...
@@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto
out_shape
=
input
->
shape
();
out_shape
.
insert
(
out_shape
.
begin
()
+
dim
,
1
,
1
);
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/fill.cc
浏览文件 @
030af09f
...
...
@@ -45,6 +45,11 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"input size: "
<<
inputs_
.
size
()
<<
", output size: "
<<
outputs_
.
size
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
fill_prim
=
this
->
primitive
->
value_as_Fill
();
if
(
fill_prim
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Fill primitive is null!"
;
...
...
@@ -53,8 +58,6 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std
::
vector
<
int
>
output_shape
;
(
void
)
output_shape
.
insert
(
output_shape
.
begin
(),
fill_prim
->
dims
()
->
begin
(),
fill_prim
->
dims
()
->
end
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/flatten.cc
浏览文件 @
030af09f
...
...
@@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_LOG
(
ERROR
)
<<
"input size: "
<<
inputs_
.
size
()
<<
", output size: "
<<
outputs_
.
size
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
output_shape
(
2
);
output_shape
[
0
]
=
input_shape
[
0
];
...
...
@@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output_shape
[
1
]
*=
input_shape
[
i
];
}
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/full_connection.cc
浏览文件 @
030af09f
...
...
@@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT
(
input1
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
((
GetHasBias
()
&&
inputs_
.
size
()
!=
kMultiNum
)
||
(
!
GetHasBias
()
&&
inputs_
.
size
()
!=
kDoubleNum
))
{
MS_LOG
(
ERROR
)
<<
"Input tensors num error"
;
return
1
;
...
...
@@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape
.
resize
(
GetAxis
()
+
1
);
out_shape
[
GetAxis
()]
=
input1
->
shape
()[
0
];
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
0
;
}
...
...
mindspore/lite/src/ops/gather_nd.cc
浏览文件 @
030af09f
...
...
@@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
MS_ASSERT
(
indices
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input
->
shape
();
int
in_rank
=
in_shape
.
size
();
auto
indices_shape
=
indices
->
shape
();
...
...
@@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
out_shape
.
emplace_back
(
in_shape
[
i
]);
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/lstm.cc
浏览文件 @
030af09f
...
...
@@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
input0
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
for
(
int
i
=
0
;
i
<
kLstmOutputNum
;
i
++
)
{
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
in_shape
=
input
->
shape
();
std
::
vector
<
int
>
w_shape
=
weight_i
->
shape
();
// layer, hidden_size * 4, input_size
if
(
in_shape
.
size
()
!=
3
||
w_shape
.
size
()
!=
3
)
{
...
...
@@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
state_shape
[
2
]
=
hidden_size
;
outputs_
[
1
]
->
set_shape
(
state_shape
);
outputs_
[
2
]
->
set_shape
(
state_shape
);
for
(
int
i
=
0
;
i
<
kLstmOutputNum
;
i
++
)
{
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/matmul.cc
浏览文件 @
030af09f
...
...
@@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
input1
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
a_shape
=
input0
->
shape
();
std
::
vector
<
int
>
b_shape
=
input1
->
shape
();
if
(
a_shape
.
size
()
<
2
||
b_shape
.
size
()
<
2
)
{
...
...
@@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std
::
vector
<
int
>
c_shape
(
a_shape
);
c_shape
[
c_shape
.
size
()
-
1
]
=
b_shape
[
b_shape
.
size
()
-
1
];
output
->
set_shape
(
c_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/mean.cc
浏览文件 @
030af09f
...
...
@@ -50,6 +50,11 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if
(
input
==
nullptr
||
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
this
->
primitive
==
nullptr
)
{
return
RET_NULL_PTR
;
}
...
...
@@ -88,8 +93,6 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
}
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/nchw2nhwc.cc
浏览文件 @
030af09f
...
...
@@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
schema
::
Format_NHWC
);
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
nchw_shape
=
input
->
shape
();
if
(
nchw_shape
.
size
()
!=
4
)
{
output
->
set_shape
(
nchw_shape
);
...
...
@@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nhwc_shape
[
NHWC_C
]
=
nchw_shape
[
NCHW_C
];
output
->
set_shape
(
nhwc_shape
);
}
output
->
SetFormat
(
schema
::
Format_NHWC
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/nhwc2nchw.cc
浏览文件 @
030af09f
...
...
@@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
schema
::
Format_NCHW
);
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
nhwc_shape
=
input
->
shape
();
if
(
nhwc_shape
.
size
()
!=
4
)
{
output
->
set_shape
(
nhwc_shape
);
...
...
@@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nchw_shape
[
NCHW_W
]
=
nhwc_shape
[
NHWC_W
];
output
->
set_shape
(
nchw_shape
);
}
output
->
SetFormat
(
schema
::
Format_NCHW
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/one_hot.cc
浏览文件 @
030af09f
...
...
@@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
if
(
input
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
on_value
=
inputs
.
at
(
2
);
if
(
on_value
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
on_value
->
data_type
());
output
->
SetFormat
(
on_value
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
const
auto
input_shape
=
input
->
shape
();
int
input_rank
=
static_cast
<
int
>
(
input_shape
.
size
());
if
(
axis
<
0
)
{
...
...
@@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
}
std
::
vector
<
int
>
output_shape
(
input_shape
);
output_shape
.
insert
(
output_shape
.
cbegin
()
+
axis
,
*
depth
);
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_shape
(
output_shape
);
auto
on_value
=
inputs
.
at
(
2
);
if
(
on_value
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
on_value
->
data_type
());
output
->
SetFormat
(
on_value
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/pad.cc
浏览文件 @
030af09f
...
...
@@ -61,6 +61,15 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if
(
input
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
output_shape
;
MS_ASSERT
(
input
->
shape
().
size
()
<=
kInputRank
);
...
...
@@ -69,13 +78,8 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
auto
shape
=
input_shape
[
i
]
+
(
*
paddings
)[
2
*
paddings_index
]
+
(
*
paddings
)[
2
*
paddings_index
+
1
];
output_shape
.
push_back
(
shape
);
}
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/pooling.cc
浏览文件 @
030af09f
...
...
@@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
schema
::
Format_NHWC
);
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
input_h
=
input
->
shape
().
at
(
1
);
int
input_w
=
input
->
shape
().
at
(
2
);
auto
pooling_prim
=
this
->
primitive
->
value_as_Pooling
();
...
...
@@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape
.
at
(
1
)
=
output_h
;
input_shape
.
at
(
2
)
=
output_w
;
output
->
set_shape
(
input_shape
);
output
->
set_data_type
(
input
->
data_type
());
// todo: temp fix
output
->
SetFormat
(
schema
::
Format_NHWC
);
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/power.cc
浏览文件 @
030af09f
...
...
@@ -49,15 +49,19 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
auto
output_tensor
=
outputs
[
0
];
MS_ASSERT
(
output_tensor
!=
nullptr
);
output_tensor
->
set_data_type
(
x_tensor
->
data_type
());
output_tensor
->
SetFormat
(
x_tensor
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
exp_tensor
!=
nullptr
)
{
if
(
exp_tensor
->
shape
()
!=
x_tensor
->
shape
()
||
exp_tensor
->
data_type
()
!=
x_tensor
->
data_type
())
{
MS_LOG
(
ERROR
)
<<
"Power inputs shape or type is not equal!"
;
return
RET_INPUT_TENSOR_ERROR
;
}
}
output_tensor
->
SetFormat
(
x_tensor
->
GetFormat
());
output_tensor
->
set_shape
(
x_tensor
->
shape
());
output_tensor
->
set_data_type
(
x_tensor
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/prior_box.cc
浏览文件 @
030af09f
...
...
@@ -99,6 +99,15 @@ constexpr int kPriorBoxC = 2;
int
PriorBox
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
auto
param
=
this
->
primitive
->
value_as_PriorBox
();
MS_ASSERT
(
param
!=
nullptr
);
auto
input
=
inputs_
.
at
(
0
);
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
at
(
0
);
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
kNumberTypeFloat32
);
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
float
>
different_aspect_ratios
{
1.0
f
};
auto
aspect_ratios
=
param
->
aspect_ratios
();
MS_ASSERT
(
aspect_ratios
!=
nullptr
);
...
...
@@ -114,15 +123,9 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
}
}
int32_t
num_priors_box
=
param
->
min_sizes
()
->
size
()
*
different_aspect_ratios
.
size
()
+
param
->
max_sizes
()
->
size
();
auto
input
=
inputs_
.
at
(
0
);
MS_ASSERT
(
input
!=
nullptr
);
int32_t
h
=
input
->
Height
()
*
input
->
Width
()
*
num_priors_box
*
kPriorBoxPoints
;
std
::
vector
<
int
>
output_shape
{
kPriorBoxN
,
h
,
kPriorBoxW
,
kPriorBoxC
};
auto
output
=
outputs_
.
at
(
0
);
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
kNumberTypeFloat32
);
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/quant_dtype_cast.cc
浏览文件 @
030af09f
...
...
@@ -40,11 +40,14 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
auto
param
=
primitive
->
value_as_QuantDTypeCast
();
MS_ASSERT
(
input
->
data_type
()
==
param
->
srcT
);
output
->
set_data_type
(
static_cast
<
TypeId
>
(
param
->
dstT
()));
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/range.cc
浏览文件 @
030af09f
...
...
@@ -50,12 +50,18 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_ASSERT
(
output
!=
nullptr
);
auto
range_prim
=
this
->
primitive
->
value_as_Range
();
MS_ASSERT
(
range_prim
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
shape_size
=
std
::
ceil
(
static_cast
<
float
>
(
range_prim
->
limit
()
-
range_prim
->
start
())
/
range_prim
->
delta
());
std
::
vector
<
int
>
in_shape
(
1
);
in_shape
.
push_back
(
shape_size
);
output
->
set_shape
(
in_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/rank.cc
浏览文件 @
030af09f
...
...
@@ -25,10 +25,13 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
std
::
vector
<
int
>
in_shape
(
1
,
1
);
output
->
set_shape
(
in_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
in_shape
(
1
,
1
);
output
->
set_shape
(
in_shape
);
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/resize.cc
浏览文件 @
030af09f
...
...
@@ -66,6 +66,11 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
if
(
output
==
nullptr
)
{
return
1
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
new_height
=
GetNewHeight
();
auto
new_width
=
GetNewWidth
();
...
...
@@ -75,10 +80,8 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
output_shape
.
push_back
(
new_width
);
output_shape
.
push_back
(
input
->
Channel
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/reverse_sequence.cc
浏览文件 @
030af09f
...
...
@@ -52,9 +52,13 @@ int ReverseSequence::InferShape(std::vector<tensor::Tensor *> inputs, std::vecto
auto
output
=
outputs
.
front
();
MS_ASSERT
(
input
!=
nullptr
);
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/roi_pooling.cc
浏览文件 @
030af09f
...
...
@@ -56,6 +56,11 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
ROIPooling
=
this
->
primitive
->
value_as_ROIPooling
();
auto
new_h
=
ROIPooling
->
pooledH
();
auto
new_w
=
ROIPooling
->
pooledW
();
...
...
@@ -66,8 +71,6 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
output_shape
.
push_back
(
new_w
);
output_shape
.
push_back
(
input
->
Channel
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/scatter_nd.cc
浏览文件 @
030af09f
...
...
@@ -51,11 +51,14 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
return
RET_ERROR
;
}
auto
output
=
outputs_
.
front
();
output
->
set_data_type
(
update
->
data_type
());
output
->
SetFormat
(
update
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
shape_data
=
reinterpret_cast
<
int
*>
(
shape
->
Data
());
std
::
vector
<
int
>
out_shape
(
shape_data
,
shape_data
+
shape
->
DataSize
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
update
->
data_type
());
output
->
SetFormat
(
update
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/space_to_batch.cc
浏览文件 @
030af09f
...
...
@@ -63,6 +63,11 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG
(
ERROR
)
<<
"space_to_batch only support NHWC now!"
;
return
1
;
}
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -106,8 +111,7 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
[
NHWC_W
]
=
input_shape
[
NHWC_W
]
/
block_sizes_
[
NHWC_H
];
output_shape
[
NHWC_C
]
=
input_shape
[
NHWC_C
];
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/space_to_depth.cc
浏览文件 @
030af09f
...
...
@@ -51,6 +51,11 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG
(
ERROR
)
<<
"space_to_depth only support NHWC now!"
;
return
1
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -69,8 +74,7 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
[
NHWC_W
]
=
input_shape
[
NHWC_W
]
/
block_size
;
output_shape
[
NHWC_C
]
=
input_shape
[
NHWC_C
]
*
(
block_size
*
block_size
);
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/split.cc
浏览文件 @
030af09f
...
...
@@ -66,6 +66,13 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG
(
ERROR
)
<<
"outputs number is not equal to "
<<
number_split
;
return
RET_ERROR
;
}
for
(
int
i
=
0
;
i
<
number_split
;
++
i
)
{
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
split_dim
=
spilt_prim
->
splitDim
();
std
::
vector
<
int
>
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
size_split
;
...
...
mindspore/lite/src/ops/squeeze.cc
浏览文件 @
030af09f
...
...
@@ -48,6 +48,11 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
return
-
1
;
}
auto
*
in_tensor
=
inputs_
.
front
();
outputs_
.
front
()
->
set_data_type
(
in_tensor
->
data_type
());
outputs_
.
front
()
->
SetFormat
(
in_tensor
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
in_tensor
->
shape
();
std
::
vector
<
int
>
out_shape
;
// todo: getAxis
...
...
@@ -77,8 +82,6 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
}
outputs_
.
front
()
->
set_shape
(
out_shape
);
outputs_
.
front
()
->
set_data_type
(
in_tensor
->
data_type
());
outputs_
.
front
()
->
SetFormat
(
in_tensor
->
GetFormat
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/stack.cc
浏览文件 @
030af09f
...
...
@@ -56,6 +56,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return
RET_PARAM_INVALID
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
auto
stack_prim
=
this
->
primitive
->
value_as_Stack
();
std
::
vector
<
int32_t
>
output_shape
=
input_shape
;
...
...
@@ -84,8 +89,6 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
output_shape
.
insert
(
output_shape
.
begin
()
+
axis
,
inputs
.
size
());
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/strided_slice.cc
浏览文件 @
030af09f
...
...
@@ -164,6 +164,11 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
return
RET_PARAM_INVALID
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
.
front
()
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
MS_ASSERT
(
input
!=
nullptr
);
auto
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
output_shape
;
...
...
@@ -214,8 +219,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
=
ApplyShrinkMask
(
output_shape
);
outputs
.
front
()
->
set_shape
(
output_shape
);
outputs
.
front
()
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/tile.cc
浏览文件 @
030af09f
...
...
@@ -40,6 +40,11 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
tile_prim
=
this
->
primitive
->
value_as_Tile
();
MS_ASSERT
(
tile_prim
!=
nullptr
);
std
::
vector
<
int
>
out_shape
;
...
...
@@ -49,9 +54,8 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
int
tmp
=
input
->
shape
()[
i
]
*
multiples
[
i
];
out_shape
.
push_back
(
tmp
);
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/topk.cc
浏览文件 @
030af09f
...
...
@@ -46,16 +46,19 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
output0
!=
nullptr
);
auto
output1
=
outputs_
.
at
(
1
);
MS_ASSERT
(
output1
!=
nullptr
);
output0
->
set_data_type
(
input
->
data_type
());
output0
->
SetFormat
(
input
->
GetFormat
());
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
topk_prim
=
this
->
primitive
->
value_as_TopK
();
MS_ASSERT
(
topk_prim
!=
nullptr
);
auto
out_shape
=
input
->
shape
();
out_shape
[
out_shape
.
size
()
-
1
]
=
topk_prim
->
k
();
output0
->
set_shape
(
out_shape
);
output0
->
set_data_type
(
input
->
data_type
());
output0
->
SetFormat
(
input
->
GetFormat
());
output1
->
set_shape
(
out_shape
);
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/unique.cc
浏览文件 @
030af09f
...
...
@@ -42,12 +42,15 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
output0
!=
nullptr
);
auto
&
output1
=
outputs_
.
at
(
1
);
MS_ASSERT
(
output1
!=
nullptr
);
output0
->
set_shape
(
input
->
shape
());
output0
->
set_data_type
(
input
->
data_type
());
output1
->
set_shape
(
input
->
shape
());
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
output0
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output0
->
set_shape
(
input
->
shape
());
output1
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/unstack.cc
浏览文件 @
030af09f
...
...
@@ -44,6 +44,14 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
MS_LOG
(
ERROR
)
<<
"Invalid axis "
<<
prim
->
axis
();
return
RET_PARAM_INVALID
;
}
for
(
auto
&
out
:
outputs
)
{
MS_ASSERT
(
out
!=
nullptr
);
out
->
set_data_type
(
input
->
data_type
());
out
->
SetFormat
(
input
->
GetFormat
());
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
output_shape
;
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
if
(
i
!=
axis
)
{
...
...
@@ -53,8 +61,6 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
for
(
auto
&
out
:
outputs
)
{
MS_ASSERT
(
out
!=
nullptr
);
out
->
set_shape
(
output_shape
);
out
->
set_data_type
(
input
->
data_type
());
out
->
SetFormat
(
input
->
GetFormat
());
}
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/where.cc
浏览文件 @
030af09f
...
...
@@ -53,6 +53,11 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto
input0
=
inputs_
.
at
(
0
);
auto
input1
=
inputs_
.
at
(
1
);
auto
input2
=
inputs_
.
at
(
2
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
num
=
input0
->
ElementsNum
();
int
num1
=
input1
->
ElementsNum
();
int
num2
=
input2
->
ElementsNum
();
...
...
@@ -85,8 +90,6 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto
output_shape
=
shape_tmp
;
output_shape
[
axisout
]
=
nummax
;
outputs_
[
0
]
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/zeros_like.cc
浏览文件 @
030af09f
...
...
@@ -29,10 +29,12 @@ int ZerosLike::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
<<
", output size: "
<<
outputs_
.
size
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c
浏览文件 @
030af09f
...
...
@@ -18,15 +18,29 @@
#include <float.h>
int
ArgCompareAscFp32
(
const
void
*
a
,
const
void
*
b
)
{
return
((
ArgElement
*
)
a
)
->
data_
.
f_data_
-
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
float
a_value
=
((
ArgElement
*
)
a
)
->
data_
.
f_data_
;
float
b_value
=
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
if
(
b_value
>
a_value
)
{
return
-
1
;
}
if
(
b_value
<
a_value
)
{
return
1
;
}
return
0
;
}
int
ArgCompareDescFp32
(
const
void
*
a
,
const
void
*
b
)
{
// cmp funtion of qsort must return int type
auto
b_value
=
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
auto
a_value
=
((
ArgElement
*
)
a
)
->
data_
.
f_data_
;
int
res
=
b_value
>
a_value
?
1
:
-
1
;
return
res
;
float
b_value
=
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
float
a_value
=
((
ArgElement
*
)
a
)
->
data_
.
f_data_
;
if
(
b_value
>
a_value
)
{
return
1
;
}
if
(
b_value
<
a_value
)
{
return
-
1
;
}
return
0
;
}
void
ArgMaxDim0OutValue
(
const
float
*
input
,
float
*
output
,
const
int
*
in_shape
,
ArgMinMaxParameter
*
param
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录