Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2f36b91a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2f36b91a
编写于
8月 18, 2020
作者:
C
chenjianping
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support infer datatype and format when shape infer fail
上级
82e8884e
变更
47
隐藏空白更改
内联
并排
Showing
47 changed file
with
298 addition
and
126 deletion
+298
-126
mindspore/lite/src/ops/addn.cc
mindspore/lite/src/ops/addn.cc
+6
-2
mindspore/lite/src/ops/argmax.cc
mindspore/lite/src/ops/argmax.cc
+7
-2
mindspore/lite/src/ops/argmin.cc
mindspore/lite/src/ops/argmin.cc
+6
-2
mindspore/lite/src/ops/broadcast_to.cc
mindspore/lite/src/ops/broadcast_to.cc
+6
-3
mindspore/lite/src/ops/cast.cc
mindspore/lite/src/ops/cast.cc
+7
-6
mindspore/lite/src/ops/constant_of_shape.cc
mindspore/lite/src/ops/constant_of_shape.cc
+6
-3
mindspore/lite/src/ops/crop.cc
mindspore/lite/src/ops/crop.cc
+4
-1
mindspore/lite/src/ops/deconv2d.cc
mindspore/lite/src/ops/deconv2d.cc
+5
-3
mindspore/lite/src/ops/dedepthwise_conv2d.cc
mindspore/lite/src/ops/dedepthwise_conv2d.cc
+5
-3
mindspore/lite/src/ops/depth_to_space.cc
mindspore/lite/src/ops/depth_to_space.cc
+6
-4
mindspore/lite/src/ops/depthwise_conv2d.cc
mindspore/lite/src/ops/depthwise_conv2d.cc
+5
-3
mindspore/lite/src/ops/embedding_lookup.cc
mindspore/lite/src/ops/embedding_lookup.cc
+6
-1
mindspore/lite/src/ops/expand_dims.cc
mindspore/lite/src/ops/expand_dims.cc
+5
-2
mindspore/lite/src/ops/fill.cc
mindspore/lite/src/ops/fill.cc
+5
-2
mindspore/lite/src/ops/flatten.cc
mindspore/lite/src/ops/flatten.cc
+7
-2
mindspore/lite/src/ops/full_connection.cc
mindspore/lite/src/ops/full_connection.cc
+5
-3
mindspore/lite/src/ops/gather_nd.cc
mindspore/lite/src/ops/gather_nd.cc
+6
-2
mindspore/lite/src/ops/lstm.cc
mindspore/lite/src/ops/lstm.cc
+9
-4
mindspore/lite/src/ops/matmul.cc
mindspore/lite/src/ops/matmul.cc
+7
-2
mindspore/lite/src/ops/mean.cc
mindspore/lite/src/ops/mean.cc
+5
-2
mindspore/lite/src/ops/nchw2nhwc.cc
mindspore/lite/src/ops/nchw2nhwc.cc
+5
-2
mindspore/lite/src/ops/nhwc2nchw.cc
mindspore/lite/src/ops/nhwc2nchw.cc
+5
-2
mindspore/lite/src/ops/one_hot.cc
mindspore/lite/src/ops/one_hot.cc
+13
-10
mindspore/lite/src/ops/pad.cc
mindspore/lite/src/ops/pad.cc
+10
-6
mindspore/lite/src/ops/pooling.cc
mindspore/lite/src/ops/pooling.cc
+5
-3
mindspore/lite/src/ops/power.cc
mindspore/lite/src/ops/power.cc
+6
-2
mindspore/lite/src/ops/prior_box.cc
mindspore/lite/src/ops/prior_box.cc
+9
-6
mindspore/lite/src/ops/quant_dtype_cast.cc
mindspore/lite/src/ops/quant_dtype_cast.cc
+4
-1
mindspore/lite/src/ops/range.cc
mindspore/lite/src/ops/range.cc
+8
-2
mindspore/lite/src/ops/rank.cc
mindspore/lite/src/ops/rank.cc
+5
-2
mindspore/lite/src/ops/resize.cc
mindspore/lite/src/ops/resize.cc
+6
-3
mindspore/lite/src/ops/reverse_sequence.cc
mindspore/lite/src/ops/reverse_sequence.cc
+5
-1
mindspore/lite/src/ops/roi_pooling.cc
mindspore/lite/src/ops/roi_pooling.cc
+5
-2
mindspore/lite/src/ops/scatter_nd.cc
mindspore/lite/src/ops/scatter_nd.cc
+5
-2
mindspore/lite/src/ops/space_to_batch.cc
mindspore/lite/src/ops/space_to_batch.cc
+6
-2
mindspore/lite/src/ops/space_to_depth.cc
mindspore/lite/src/ops/space_to_depth.cc
+6
-2
mindspore/lite/src/ops/split.cc
mindspore/lite/src/ops/split.cc
+7
-0
mindspore/lite/src/ops/squeeze.cc
mindspore/lite/src/ops/squeeze.cc
+5
-2
mindspore/lite/src/ops/stack.cc
mindspore/lite/src/ops/stack.cc
+5
-2
mindspore/lite/src/ops/strided_slice.cc
mindspore/lite/src/ops/strided_slice.cc
+5
-2
mindspore/lite/src/ops/tile.cc
mindspore/lite/src/ops/tile.cc
+6
-2
mindspore/lite/src/ops/topk.cc
mindspore/lite/src/ops/topk.cc
+7
-4
mindspore/lite/src/ops/unique.cc
mindspore/lite/src/ops/unique.cc
+5
-2
mindspore/lite/src/ops/unstack.cc
mindspore/lite/src/ops/unstack.cc
+8
-2
mindspore/lite/src/ops/where.cc
mindspore/lite/src/ops/where.cc
+5
-2
mindspore/lite/src/ops/zeros_like.cc
mindspore/lite/src/ops/zeros_like.cc
+4
-2
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c
...pore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c
+20
-6
未找到文件。
mindspore/lite/src/ops/addn.cc
浏览文件 @
2f36b91a
...
...
@@ -43,6 +43,11 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG
(
ERROR
)
<<
"input size"
<<
inputs
.
size
()
<<
" is error!"
;
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
for
(
int
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
if
(
inputs
.
at
(
i
)
->
shape
()
!=
inputs
.
at
(
0
)
->
shape
())
{
MS_LOG
(
ERROR
)
<<
"AddN inputs shape is not equal!"
;
...
...
@@ -53,9 +58,8 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return
RET_INPUT_TENSOR_ERROR
;
}
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/argmax.cc
浏览文件 @
2f36b91a
...
...
@@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if
(
inputs_
.
size
()
!=
kSingleNum
||
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"tensor number is error."
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
argmax_prim
=
this
->
primitive
->
value_as_ArgMax
();
std
::
vector
<
int
>
output_shape
(
input
->
shape
());
auto
input_shape_size
=
input
->
shape
().
size
();
...
...
@@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
else
{
output_shape
[
axis
]
=
argmax_prim
->
topK
();
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/argmin.cc
浏览文件 @
2f36b91a
...
...
@@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if
(
inputs_
.
size
()
!=
kSingleNum
||
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"tensor number is error."
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
argmin_prim
=
this
->
primitive
->
value_as_ArgMin
();
auto
input_shape_size
=
input
->
shape
().
size
();
int
axis
=
argmin_prim
->
axis
()
<
0
?
argmin_prim
->
axis
()
+
input_shape_size
:
argmin_prim
->
axis
();
...
...
@@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
else
{
output_shape
[
axis
]
=
argmin_prim
->
topK
();
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/broadcast_to.cc
浏览文件 @
2f36b91a
...
...
@@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
return
1
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int32_t
>
dst_shape
(
this
->
primitive
->
value_as_BroadcastTo
()
->
dst_shape
()
->
begin
(),
this
->
primitive
->
value_as_BroadcastTo
()
->
dst_shape
()
->
end
());
auto
input_shape
=
input
->
shape
();
...
...
@@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
shape
[
i
]
=
dst_shape
[
i
];
--
input_shape_index
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_shape
(
shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/cast.cc
浏览文件 @
2f36b91a
...
...
@@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"tensor number is error."
;
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
auto
cast_prim
=
this
->
primitive
->
value_as_Cast
();
MS_ASSERT
(
cast_prim
!=
nullptr
);
output
->
set_data_type
(
static_cast
<
TypeId
>
(
cast_prim
->
dstT
()));
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
input
->
data_type
()
!=
cast_prim
->
srcT
())
{
MS_LOG
(
ERROR
)
<<
"input dataType is error"
;
return
RET_INPUT_TENSOR_ERROR
;
...
...
@@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"Unsupported input data type "
<<
input
->
data_type
();
return
RET_INPUT_TENSOR_ERROR
;
}
if
(
cast_prim
->
dstT
()
!=
kNumberTypeFloat
&&
cast_prim
->
dstT
()
!=
kNumberTypeFloat32
)
{
MS_LOG
(
ERROR
)
<<
"Invalid output datatype "
<<
cast_prim
->
dstT
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
TypeId
::
kNumberTypeFloat32
);
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/constant_of_shape.cc
浏览文件 @
2f36b91a
...
...
@@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
return
RET_ERROR
;
}
auto
in_tensor
=
inputs_
.
front
();
auto
in_data
=
reinterpret_cast
<
int
*>
(
in_tensor
->
Data
());
auto
out_tensor
=
outputs_
.
front
();
out_tensor
->
set_data_type
(
kNumberTypeFloat32
);
out_tensor
->
SetFormat
(
in_tensor
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_data
=
reinterpret_cast
<
int
*>
(
in_tensor
->
Data
());
int
size
=
in_tensor
->
ElementsNum
();
std
::
vector
<
int
>
out_shape
(
size
);
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
out_shape
[
i
]
=
in_data
[
i
];
}
out_tensor
->
set_shape
(
out_shape
);
out_tensor
->
set_data_type
(
kNumberTypeFloat32
);
out_tensor
->
SetFormat
(
in_tensor
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/crop.cc
浏览文件 @
2f36b91a
...
...
@@ -46,9 +46,12 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG
(
ERROR
)
<<
"Invalid output/input size! output size: "
<<
outputs
.
size
()
<<
",input size: "
<<
inputs
.
size
();
return
RET_PARAM_INVALID
;
}
outputs
[
0
]
->
set_shape
(
inputs
[
1
]
->
shape
());
outputs
[
0
]
->
SetFormat
(
inputs
[
0
]
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
inputs
[
0
]
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
outputs
[
0
]
->
set_shape
(
inputs
[
1
]
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/deconv2d.cc
浏览文件 @
2f36b91a
...
...
@@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
weight
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int32_t
input_h
=
input
->
Height
();
int32_t
input_w
=
input
->
Width
();
...
...
@@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
std
::
vector
<
int
>
out_shape
=
{
output_n
,
output_h
,
output_w
,
output_c
};
output
->
set_shape
(
out_shape
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/dedepthwise_conv2d.cc
浏览文件 @
2f36b91a
...
...
@@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT
(
weight
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input
->
shape
();
int
input_h
=
in_shape
.
at
(
1
);
int
input_w
=
in_shape
.
at
(
2
);
...
...
@@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape
.
at
(
3
)
=
weight
->
shape
()[
0
]
*
weight
->
shape
()[
3
];
// in_channel * out_channel
output
->
set_shape
(
out_shape
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/depth_to_space.cc
浏览文件 @
2f36b91a
...
...
@@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG
(
ERROR
)
<<
"depth_to_space only support NHWC now!"
;
return
1
;
}
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
[
NHWC_W
]
=
input_shape
[
NHWC_W
]
*
block_size
;
output_shape
[
NHWC_C
]
=
input_shape
[
NHWC_C
]
/
(
block_size
*
block_size
);
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/depthwise_conv2d.cc
浏览文件 @
2f36b91a
...
...
@@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT
(
weight
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input
->
shape
();
int
input_h
=
in_shape
.
at
(
1
);
int
input_w
=
in_shape
.
at
(
2
);
...
...
@@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape
.
at
(
3
)
=
weight
->
shape
()[
0
]
*
weight
->
shape
()[
3
];
// in_channel * out_channel
output
->
set_shape
(
out_shape
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/embedding_lookup.cc
浏览文件 @
2f36b91a
...
...
@@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
MS_ASSERT
(
ids
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
params_
->
GetFormat
());
output
->
set_data_type
(
params_
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
embedding_shape
=
params_
->
shape
();
embedding_shape
.
erase
(
embedding_shape
.
begin
());
std
::
vector
<
int
>
output_shape
(
ids
->
shape
());
...
...
@@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
}
}
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
params_
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/expand_dims.cc
浏览文件 @
2f36b91a
...
...
@@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if
(
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"output size is invalid"
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
expand_dims_prim
=
this
->
primitive
->
value_as_ExpandDims
();
int
dim
=
expand_dims_prim
->
dim
();
if
(
dim
<
0
)
{
...
...
@@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto
out_shape
=
input
->
shape
();
out_shape
.
insert
(
out_shape
.
begin
()
+
dim
,
1
,
1
);
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/fill.cc
浏览文件 @
2f36b91a
...
...
@@ -45,6 +45,11 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"input size: "
<<
inputs_
.
size
()
<<
", output size: "
<<
outputs_
.
size
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
fill_prim
=
this
->
primitive
->
value_as_Fill
();
if
(
fill_prim
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Fill primitive is null!"
;
...
...
@@ -53,8 +58,6 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std
::
vector
<
int
>
output_shape
;
(
void
)
output_shape
.
insert
(
output_shape
.
begin
(),
fill_prim
->
dims
()
->
begin
(),
fill_prim
->
dims
()
->
end
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/flatten.cc
浏览文件 @
2f36b91a
...
...
@@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_LOG
(
ERROR
)
<<
"input size: "
<<
inputs_
.
size
()
<<
", output size: "
<<
outputs_
.
size
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
output_shape
(
2
);
output_shape
[
0
]
=
input_shape
[
0
];
...
...
@@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output_shape
[
1
]
*=
input_shape
[
i
];
}
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/full_connection.cc
浏览文件 @
2f36b91a
...
...
@@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT
(
input1
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
((
GetHasBias
()
&&
inputs_
.
size
()
!=
kMultiNum
)
||
(
!
GetHasBias
()
&&
inputs_
.
size
()
!=
kDoubleNum
))
{
MS_LOG
(
ERROR
)
<<
"Input tensors num error"
;
return
1
;
...
...
@@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape
.
resize
(
GetAxis
()
+
1
);
out_shape
[
GetAxis
()]
=
input1
->
shape
()[
0
];
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
0
;
}
...
...
mindspore/lite/src/ops/gather_nd.cc
浏览文件 @
2f36b91a
...
...
@@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
MS_ASSERT
(
indices
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
input
->
shape
();
int
in_rank
=
in_shape
.
size
();
auto
indices_shape
=
indices
->
shape
();
...
...
@@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
out_shape
.
emplace_back
(
in_shape
[
i
]);
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/lstm.cc
浏览文件 @
2f36b91a
...
...
@@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
input0
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
for
(
int
i
=
0
;
i
<
kLstmOutputNum
;
i
++
)
{
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
in_shape
=
input
->
shape
();
std
::
vector
<
int
>
w_shape
=
weight_i
->
shape
();
// layer, hidden_size * 4, input_size
if
(
in_shape
.
size
()
!=
3
||
w_shape
.
size
()
!=
3
)
{
...
...
@@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
state_shape
[
2
]
=
hidden_size
;
outputs_
[
1
]
->
set_shape
(
state_shape
);
outputs_
[
2
]
->
set_shape
(
state_shape
);
for
(
int
i
=
0
;
i
<
kLstmOutputNum
;
i
++
)
{
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/matmul.cc
浏览文件 @
2f36b91a
...
...
@@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
input1
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
a_shape
=
input0
->
shape
();
std
::
vector
<
int
>
b_shape
=
input1
->
shape
();
if
(
a_shape
.
size
()
<
2
||
b_shape
.
size
()
<
2
)
{
...
...
@@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std
::
vector
<
int
>
c_shape
(
a_shape
);
c_shape
[
c_shape
.
size
()
-
1
]
=
b_shape
[
b_shape
.
size
()
-
1
];
output
->
set_shape
(
c_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/mean.cc
浏览文件 @
2f36b91a
...
...
@@ -50,6 +50,11 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if
(
input
==
nullptr
||
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
this
->
primitive
==
nullptr
)
{
return
RET_NULL_PTR
;
}
...
...
@@ -88,8 +93,6 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
}
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/nchw2nhwc.cc
浏览文件 @
2f36b91a
...
...
@@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
schema
::
Format_NHWC
);
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
nchw_shape
=
input
->
shape
();
if
(
nchw_shape
.
size
()
!=
4
)
{
output
->
set_shape
(
nchw_shape
);
...
...
@@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nhwc_shape
[
NHWC_C
]
=
nchw_shape
[
NCHW_C
];
output
->
set_shape
(
nhwc_shape
);
}
output
->
SetFormat
(
schema
::
Format_NHWC
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/nhwc2nchw.cc
浏览文件 @
2f36b91a
...
...
@@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
schema
::
Format_NCHW
);
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
nhwc_shape
=
input
->
shape
();
if
(
nhwc_shape
.
size
()
!=
4
)
{
output
->
set_shape
(
nhwc_shape
);
...
...
@@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nchw_shape
[
NCHW_W
]
=
nhwc_shape
[
NHWC_W
];
output
->
set_shape
(
nchw_shape
);
}
output
->
SetFormat
(
schema
::
Format_NCHW
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/one_hot.cc
浏览文件 @
2f36b91a
...
...
@@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
if
(
input
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
on_value
=
inputs
.
at
(
2
);
if
(
on_value
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
on_value
->
data_type
());
output
->
SetFormat
(
on_value
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
const
auto
input_shape
=
input
->
shape
();
int
input_rank
=
static_cast
<
int
>
(
input_shape
.
size
());
if
(
axis
<
0
)
{
...
...
@@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
}
std
::
vector
<
int
>
output_shape
(
input_shape
);
output_shape
.
insert
(
output_shape
.
cbegin
()
+
axis
,
*
depth
);
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_shape
(
output_shape
);
auto
on_value
=
inputs
.
at
(
2
);
if
(
on_value
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
on_value
->
data_type
());
output
->
SetFormat
(
on_value
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/pad.cc
浏览文件 @
2f36b91a
...
...
@@ -61,6 +61,15 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if
(
input
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
output_shape
;
MS_ASSERT
(
input
->
shape
().
size
()
<=
kInputRank
);
...
...
@@ -69,13 +78,8 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
auto
shape
=
input_shape
[
i
]
+
(
*
paddings
)[
2
*
paddings_index
]
+
(
*
paddings
)[
2
*
paddings_index
+
1
];
output_shape
.
push_back
(
shape
);
}
auto
output
=
outputs
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/pooling.cc
浏览文件 @
2f36b91a
...
...
@@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
schema
::
Format_NHWC
);
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
input_h
=
input
->
shape
().
at
(
1
);
int
input_w
=
input
->
shape
().
at
(
2
);
auto
pooling_prim
=
this
->
primitive
->
value_as_Pooling
();
...
...
@@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape
.
at
(
1
)
=
output_h
;
input_shape
.
at
(
2
)
=
output_w
;
output
->
set_shape
(
input_shape
);
output
->
set_data_type
(
input
->
data_type
());
// todo: temp fix
output
->
SetFormat
(
schema
::
Format_NHWC
);
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/power.cc
浏览文件 @
2f36b91a
...
...
@@ -49,15 +49,19 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
auto
output_tensor
=
outputs
[
0
];
MS_ASSERT
(
output_tensor
!=
nullptr
);
output_tensor
->
set_data_type
(
x_tensor
->
data_type
());
output_tensor
->
SetFormat
(
x_tensor
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
if
(
exp_tensor
!=
nullptr
)
{
if
(
exp_tensor
->
shape
()
!=
x_tensor
->
shape
()
||
exp_tensor
->
data_type
()
!=
x_tensor
->
data_type
())
{
MS_LOG
(
ERROR
)
<<
"Power inputs shape or type is not equal!"
;
return
RET_INPUT_TENSOR_ERROR
;
}
}
output_tensor
->
SetFormat
(
x_tensor
->
GetFormat
());
output_tensor
->
set_shape
(
x_tensor
->
shape
());
output_tensor
->
set_data_type
(
x_tensor
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/prior_box.cc
浏览文件 @
2f36b91a
...
...
@@ -99,6 +99,15 @@ constexpr int kPriorBoxC = 2;
int
PriorBox
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
auto
param
=
this
->
primitive
->
value_as_PriorBox
();
MS_ASSERT
(
param
!=
nullptr
);
auto
input
=
inputs_
.
at
(
0
);
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
at
(
0
);
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
kNumberTypeFloat32
);
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
float
>
different_aspect_ratios
{
1.0
f
};
auto
aspect_ratios
=
param
->
aspect_ratios
();
MS_ASSERT
(
aspect_ratios
!=
nullptr
);
...
...
@@ -114,15 +123,9 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
}
}
int32_t
num_priors_box
=
param
->
min_sizes
()
->
size
()
*
different_aspect_ratios
.
size
()
+
param
->
max_sizes
()
->
size
();
auto
input
=
inputs_
.
at
(
0
);
MS_ASSERT
(
input
!=
nullptr
);
int32_t
h
=
input
->
Height
()
*
input
->
Width
()
*
num_priors_box
*
kPriorBoxPoints
;
std
::
vector
<
int
>
output_shape
{
kPriorBoxN
,
h
,
kPriorBoxW
,
kPriorBoxC
};
auto
output
=
outputs_
.
at
(
0
);
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
kNumberTypeFloat32
);
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/quant_dtype_cast.cc
浏览文件 @
2f36b91a
...
...
@@ -40,11 +40,14 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
auto
param
=
primitive
->
value_as_QuantDTypeCast
();
MS_ASSERT
(
input
->
data_type
()
==
param
->
srcT
);
output
->
set_data_type
(
static_cast
<
TypeId
>
(
param
->
dstT
()));
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/range.cc
浏览文件 @
2f36b91a
...
...
@@ -50,12 +50,18 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_ASSERT
(
output
!=
nullptr
);
auto
range_prim
=
this
->
primitive
->
value_as_Range
();
MS_ASSERT
(
range_prim
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
shape_size
=
std
::
ceil
(
static_cast
<
float
>
(
range_prim
->
limit
()
-
range_prim
->
start
())
/
range_prim
->
delta
());
std
::
vector
<
int
>
in_shape
(
1
);
in_shape
.
push_back
(
shape_size
);
output
->
set_shape
(
in_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/rank.cc
浏览文件 @
2f36b91a
...
...
@@ -25,10 +25,13 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
std
::
vector
<
int
>
in_shape
(
1
,
1
);
output
->
set_shape
(
in_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
in_shape
(
1
,
1
);
output
->
set_shape
(
in_shape
);
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/resize.cc
浏览文件 @
2f36b91a
...
...
@@ -66,6 +66,11 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
if
(
output
==
nullptr
)
{
return
1
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
new_height
=
GetNewHeight
();
auto
new_width
=
GetNewWidth
();
...
...
@@ -75,10 +80,8 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
output_shape
.
push_back
(
new_width
);
output_shape
.
push_back
(
input
->
Channel
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/reverse_sequence.cc
浏览文件 @
2f36b91a
...
...
@@ -52,9 +52,13 @@ int ReverseSequence::InferShape(std::vector<tensor::Tensor *> inputs, std::vecto
auto
output
=
outputs
.
front
();
MS_ASSERT
(
input
!=
nullptr
);
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/roi_pooling.cc
浏览文件 @
2f36b91a
...
...
@@ -56,6 +56,11 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
ROIPooling
=
this
->
primitive
->
value_as_ROIPooling
();
auto
new_h
=
ROIPooling
->
pooledH
();
auto
new_w
=
ROIPooling
->
pooledW
();
...
...
@@ -66,8 +71,6 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
output_shape
.
push_back
(
new_w
);
output_shape
.
push_back
(
input
->
Channel
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/scatter_nd.cc
浏览文件 @
2f36b91a
...
...
@@ -51,11 +51,14 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
return
RET_ERROR
;
}
auto
output
=
outputs_
.
front
();
output
->
set_data_type
(
update
->
data_type
());
output
->
SetFormat
(
update
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
shape_data
=
reinterpret_cast
<
int
*>
(
shape
->
Data
());
std
::
vector
<
int
>
out_shape
(
shape_data
,
shape_data
+
shape
->
DataSize
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
update
->
data_type
());
output
->
SetFormat
(
update
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/space_to_batch.cc
浏览文件 @
2f36b91a
...
...
@@ -63,6 +63,11 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG
(
ERROR
)
<<
"space_to_batch only support NHWC now!"
;
return
1
;
}
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -106,8 +111,7 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
[
NHWC_W
]
=
input_shape
[
NHWC_W
]
/
block_sizes_
[
NHWC_H
];
output_shape
[
NHWC_C
]
=
input_shape
[
NHWC_C
];
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/space_to_depth.cc
浏览文件 @
2f36b91a
...
...
@@ -51,6 +51,11 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG
(
ERROR
)
<<
"space_to_depth only support NHWC now!"
;
return
1
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
if
(
input_shape
.
size
()
!=
kDimension_4d
)
{
MS_LOG
(
ERROR
)
<<
"input shape dimension size should == "
<<
kDimension_4d
;
...
...
@@ -69,8 +74,7 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
[
NHWC_W
]
=
input_shape
[
NHWC_W
]
/
block_size
;
output_shape
[
NHWC_C
]
=
input_shape
[
NHWC_C
]
*
(
block_size
*
block_size
);
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
0
;
return
RET_OK
;
}
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/split.cc
浏览文件 @
2f36b91a
...
...
@@ -66,6 +66,13 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG
(
ERROR
)
<<
"outputs number is not equal to "
<<
number_split
;
return
RET_ERROR
;
}
for
(
int
i
=
0
;
i
<
number_split
;
++
i
)
{
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
split_dim
=
spilt_prim
->
splitDim
();
std
::
vector
<
int
>
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
size_split
;
...
...
mindspore/lite/src/ops/squeeze.cc
浏览文件 @
2f36b91a
...
...
@@ -48,6 +48,11 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
return
-
1
;
}
auto
*
in_tensor
=
inputs_
.
front
();
outputs_
.
front
()
->
set_data_type
(
in_tensor
->
data_type
());
outputs_
.
front
()
->
SetFormat
(
in_tensor
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
in_shape
=
in_tensor
->
shape
();
std
::
vector
<
int
>
out_shape
;
// todo: getAxis
...
...
@@ -77,8 +82,6 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
}
outputs_
.
front
()
->
set_shape
(
out_shape
);
outputs_
.
front
()
->
set_data_type
(
in_tensor
->
data_type
());
outputs_
.
front
()
->
SetFormat
(
in_tensor
->
GetFormat
());
return
0
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/stack.cc
浏览文件 @
2f36b91a
...
...
@@ -56,6 +56,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return
RET_PARAM_INVALID
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
input_shape
=
input
->
shape
();
auto
stack_prim
=
this
->
primitive
->
value_as_Stack
();
std
::
vector
<
int32_t
>
output_shape
=
input_shape
;
...
...
@@ -84,8 +89,6 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
output_shape
.
insert
(
output_shape
.
begin
()
+
axis
,
inputs
.
size
());
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/strided_slice.cc
浏览文件 @
2f36b91a
...
...
@@ -164,6 +164,11 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
return
RET_PARAM_INVALID
;
}
auto
input
=
inputs
.
at
(
0
);
outputs
.
front
()
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
MS_ASSERT
(
input
!=
nullptr
);
auto
input_shape
=
input
->
shape
();
std
::
vector
<
int
>
output_shape
;
...
...
@@ -214,8 +219,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape
=
ApplyShrinkMask
(
output_shape
);
outputs
.
front
()
->
set_shape
(
output_shape
);
outputs
.
front
()
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/tile.cc
浏览文件 @
2f36b91a
...
...
@@ -40,6 +40,11 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
tile_prim
=
this
->
primitive
->
value_as_Tile
();
MS_ASSERT
(
tile_prim
!=
nullptr
);
std
::
vector
<
int
>
out_shape
;
...
...
@@ -49,9 +54,8 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
int
tmp
=
input
->
shape
()[
i
]
*
multiples
[
i
];
out_shape
.
push_back
(
tmp
);
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/topk.cc
浏览文件 @
2f36b91a
...
...
@@ -46,16 +46,19 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT
(
output0
!=
nullptr
);
auto
output1
=
outputs_
.
at
(
1
);
MS_ASSERT
(
output1
!=
nullptr
);
output0
->
set_data_type
(
input
->
data_type
());
output0
->
SetFormat
(
input
->
GetFormat
());
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
auto
topk_prim
=
this
->
primitive
->
value_as_TopK
();
MS_ASSERT
(
topk_prim
!=
nullptr
);
auto
out_shape
=
input
->
shape
();
out_shape
[
out_shape
.
size
()
-
1
]
=
topk_prim
->
k
();
output0
->
set_shape
(
out_shape
);
output0
->
set_data_type
(
input
->
data_type
());
output0
->
SetFormat
(
input
->
GetFormat
());
output1
->
set_shape
(
out_shape
);
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/unique.cc
浏览文件 @
2f36b91a
...
...
@@ -42,12 +42,15 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT
(
output0
!=
nullptr
);
auto
&
output1
=
outputs_
.
at
(
1
);
MS_ASSERT
(
output1
!=
nullptr
);
output0
->
set_shape
(
input
->
shape
());
output0
->
set_data_type
(
input
->
data_type
());
output1
->
set_shape
(
input
->
shape
());
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
output0
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output0
->
set_shape
(
input
->
shape
());
output1
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/unstack.cc
浏览文件 @
2f36b91a
...
...
@@ -44,6 +44,14 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
MS_LOG
(
ERROR
)
<<
"Invalid axis "
<<
prim
->
axis
();
return
RET_PARAM_INVALID
;
}
for
(
auto
&
out
:
outputs
)
{
MS_ASSERT
(
out
!=
nullptr
);
out
->
set_data_type
(
input
->
data_type
());
out
->
SetFormat
(
input
->
GetFormat
());
}
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
std
::
vector
<
int
>
output_shape
;
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
if
(
i
!=
axis
)
{
...
...
@@ -53,8 +61,6 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
for
(
auto
&
out
:
outputs
)
{
MS_ASSERT
(
out
!=
nullptr
);
out
->
set_shape
(
output_shape
);
out
->
set_data_type
(
input
->
data_type
());
out
->
SetFormat
(
input
->
GetFormat
());
}
return
RET_OK
;
}
...
...
mindspore/lite/src/ops/where.cc
浏览文件 @
2f36b91a
...
...
@@ -53,6 +53,11 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto
input0
=
inputs_
.
at
(
0
);
auto
input1
=
inputs_
.
at
(
1
);
auto
input2
=
inputs_
.
at
(
2
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
int
num
=
input0
->
ElementsNum
();
int
num1
=
input1
->
ElementsNum
();
int
num2
=
input2
->
ElementsNum
();
...
...
@@ -85,8 +90,6 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto
output_shape
=
shape_tmp
;
output_shape
[
axisout
]
=
nummax
;
outputs_
[
0
]
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/ops/zeros_like.cc
浏览文件 @
2f36b91a
...
...
@@ -29,10 +29,12 @@ int ZerosLike::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
<<
", output size: "
<<
outputs_
.
size
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
if
(
!
GetInferFlag
())
{
return
RET_OK
;
}
output
->
set_shape
(
input
->
shape
());
return
RET_OK
;
}
}
// namespace lite
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c
浏览文件 @
2f36b91a
...
...
@@ -18,15 +18,29 @@
#include <float.h>
int
ArgCompareAscFp32
(
const
void
*
a
,
const
void
*
b
)
{
return
((
ArgElement
*
)
a
)
->
data_
.
f_data_
-
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
float
a_value
=
((
ArgElement
*
)
a
)
->
data_
.
f_data_
;
float
b_value
=
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
if
(
b_value
>
a_value
)
{
return
-
1
;
}
if
(
b_value
<
a_value
)
{
return
1
;
}
return
0
;
}
int
ArgCompareDescFp32
(
const
void
*
a
,
const
void
*
b
)
{
// cmp funtion of qsort must return int type
auto
b_value
=
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
auto
a_value
=
((
ArgElement
*
)
a
)
->
data_
.
f_data_
;
int
res
=
b_value
>
a_value
?
1
:
-
1
;
return
res
;
float
b_value
=
((
ArgElement
*
)
b
)
->
data_
.
f_data_
;
float
a_value
=
((
ArgElement
*
)
a
)
->
data_
.
f_data_
;
if
(
b_value
>
a_value
)
{
return
1
;
}
if
(
b_value
<
a_value
)
{
return
-
1
;
}
return
0
;
}
void
ArgMaxDim0OutValue
(
const
float
*
input
,
float
*
output
,
const
int
*
in_shape
,
ArgMinMaxParameter
*
param
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录