Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
78c91228
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
78c91228
编写于
7月 29, 2020
作者:
C
cjh9368
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug for anf_exporter graph input tensor format and op output format
上级
1b699234
变更
44
隐藏空白更改
内联
并排
Showing
44 changed file
with
100 addition
and
60 deletion
+100
-60
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+1
-0
mindspore/lite/src/ops/addn.cc
mindspore/lite/src/ops/addn.cc
+1
-0
mindspore/lite/src/ops/argmax.cc
mindspore/lite/src/ops/argmax.cc
+1
-0
mindspore/lite/src/ops/argmin.cc
mindspore/lite/src/ops/argmin.cc
+1
-1
mindspore/lite/src/ops/arithmetic.cc
mindspore/lite/src/ops/arithmetic.cc
+3
-2
mindspore/lite/src/ops/arithmetic_self.cc
mindspore/lite/src/ops/arithmetic_self.cc
+3
-1
mindspore/lite/src/ops/batch_to_space.cc
mindspore/lite/src/ops/batch_to_space.cc
+2
-1
mindspore/lite/src/ops/broadcast_to.cc
mindspore/lite/src/ops/broadcast_to.cc
+1
-1
mindspore/lite/src/ops/cast.cc
mindspore/lite/src/ops/cast.cc
+1
-1
mindspore/lite/src/ops/concat.cc
mindspore/lite/src/ops/concat.cc
+2
-1
mindspore/lite/src/ops/crop.cc
mindspore/lite/src/ops/crop.cc
+2
-1
mindspore/lite/src/ops/depth_to_space.cc
mindspore/lite/src/ops/depth_to_space.cc
+3
-2
mindspore/lite/src/ops/expand_dims.cc
mindspore/lite/src/ops/expand_dims.cc
+2
-1
mindspore/lite/src/ops/fill.cc
mindspore/lite/src/ops/fill.cc
+2
-1
mindspore/lite/src/ops/flatten.cc
mindspore/lite/src/ops/flatten.cc
+2
-1
mindspore/lite/src/ops/fullconnection.cc
mindspore/lite/src/ops/fullconnection.cc
+2
-1
mindspore/lite/src/ops/gather.cc
mindspore/lite/src/ops/gather.cc
+2
-1
mindspore/lite/src/ops/gather_nd.cc
mindspore/lite/src/ops/gather_nd.cc
+2
-1
mindspore/lite/src/ops/matmul.cc
mindspore/lite/src/ops/matmul.cc
+2
-1
mindspore/lite/src/ops/one_hot.cc
mindspore/lite/src/ops/one_hot.cc
+2
-0
mindspore/lite/src/ops/ops.cc
mindspore/lite/src/ops/ops.cc
+2
-1
mindspore/lite/src/ops/pad.cc
mindspore/lite/src/ops/pad.cc
+1
-1
mindspore/lite/src/ops/pooling.cc
mindspore/lite/src/ops/pooling.cc
+1
-0
mindspore/lite/src/ops/range.cc
mindspore/lite/src/ops/range.cc
+2
-1
mindspore/lite/src/ops/rank.cc
mindspore/lite/src/ops/rank.cc
+2
-1
mindspore/lite/src/ops/reduce.cc
mindspore/lite/src/ops/reduce.cc
+2
-0
mindspore/lite/src/ops/reshape.cc
mindspore/lite/src/ops/reshape.cc
+2
-1
mindspore/lite/src/ops/resize.cc
mindspore/lite/src/ops/resize.cc
+2
-1
mindspore/lite/src/ops/scatter_nd.cc
mindspore/lite/src/ops/scatter_nd.cc
+2
-1
mindspore/lite/src/ops/slice.cc
mindspore/lite/src/ops/slice.cc
+7
-6
mindspore/lite/src/ops/softmax.cc
mindspore/lite/src/ops/softmax.cc
+2
-1
mindspore/lite/src/ops/split.cc
mindspore/lite/src/ops/split.cc
+1
-1
mindspore/lite/src/ops/squeeze.cc
mindspore/lite/src/ops/squeeze.cc
+15
-15
mindspore/lite/src/ops/stack.cc
mindspore/lite/src/ops/stack.cc
+3
-2
mindspore/lite/src/ops/strided_slice.cc
mindspore/lite/src/ops/strided_slice.cc
+2
-0
mindspore/lite/src/ops/tile.cc
mindspore/lite/src/ops/tile.cc
+1
-1
mindspore/lite/src/ops/topk.cc
mindspore/lite/src/ops/topk.cc
+4
-3
mindspore/lite/src/ops/transpose.cc
mindspore/lite/src/ops/transpose.cc
+2
-1
mindspore/lite/src/ops/unique.cc
mindspore/lite/src/ops/unique.cc
+3
-1
mindspore/lite/src/ops/unsqueeze.cc
mindspore/lite/src/ops/unsqueeze.cc
+1
-1
mindspore/lite/src/ops/unstack.cc
mindspore/lite/src/ops/unstack.cc
+1
-1
mindspore/lite/src/ops/where.cc
mindspore/lite/src/ops/where.cc
+2
-1
mindspore/lite/src/ops/zeroslike.cc
mindspore/lite/src/ops/zeroslike.cc
+2
-1
mindspore/lite/tools/benchmark/benchmark.cc
mindspore/lite/tools/benchmark/benchmark.cc
+1
-0
未找到文件。
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
浏览文件 @
78c91228
...
...
@@ -150,6 +150,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto
tensor
=
metaGraphT
->
allTensors
[
input
].
get
();
if
(
tensor
->
data
.
empty
())
{
tensor
->
nodeType
=
schema
::
NodeType_ValueNode
;
tensor
->
format
=
schema
::
Format_NHWC
;
// tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT;
metaGraphT
->
inputIndex
.
emplace_back
(
input
);
}
...
...
mindspore/lite/src/ops/addn.cc
浏览文件 @
78c91228
...
...
@@ -36,6 +36,7 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
return
RET_INPUT_TENSOR_ERROR
;
}
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
...
...
mindspore/lite/src/ops/argmax.cc
浏览文件 @
78c91228
...
...
@@ -40,6 +40,7 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output_shape
.
erase
(
output_shape
.
begin
()
+
axis
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
...
...
mindspore/lite/src/ops/argmin.cc
浏览文件 @
78c91228
...
...
@@ -39,9 +39,9 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std
::
vector
<
int
>
output_shape
(
input
->
shape
());
output_shape
.
erase
(
output_shape
.
begin
()
+
axis
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/arithmetic.cc
浏览文件 @
78c91228
...
...
@@ -39,7 +39,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto
input_shape0
=
input0
->
shape
();
auto
input_shape1
=
input1
->
shape
();
auto
format
=
input0
->
GetFormat
();
in_shape0_
.
resize
(
5
);
in_shape1_
.
resize
(
5
);
out_shape_
.
resize
(
5
);
...
...
@@ -57,6 +57,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
in_shape1_
[
i
]
=
input_shape1
[
i
];
}
format
=
input0
->
GetFormat
();
}
else
if
(
input_shape0
.
size
()
>
input_shape1
.
size
())
{
ndim_
=
input_shape0
.
size
();
auto
fill_dim_num
=
input_shape0
.
size
()
-
input_shape1
.
size
();
...
...
@@ -93,7 +94,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
output_shape
.
push_back
(
out_shape_
[
i
]);
}
output
->
SetFormat
(
format
);
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input0
->
data_type
());
return
RET_OK
;
...
...
mindspore/lite/src/ops/arithmetic_self.cc
浏览文件 @
78c91228
...
...
@@ -26,9 +26,11 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
input
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/batch_to_space.cc
浏览文件 @
78c91228
...
...
@@ -85,9 +85,10 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape
[
kNHWC_h_index
]
=
input_shape
[
kNHWC_h_index
]
*
block_shape
->
Get
(
0
)
-
crops
->
Get
(
0
)
-
crops
->
Get
(
1
);
output_shape
[
kNHWC_w_index
]
=
input_shape
[
kNHWC_w_index
]
*
block_shape
->
Get
(
1
)
-
crops
->
Get
(
2
)
-
crops
->
Get
(
3
);
output_shape
[
kNHWC_c_index
]
=
input_shape
[
kNHWC_c_index
];
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/broadcast_to.cc
浏览文件 @
78c91228
...
...
@@ -58,9 +58,9 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
shape
[
i
]
=
dst_shape
[
i
];
--
input_shape_index
;
}
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
outputs
[
0
]
->
set_shape
(
shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/cast.cc
浏览文件 @
78c91228
...
...
@@ -44,9 +44,9 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG
(
ERROR
)
<<
"Invalid output datatype "
<<
cast_prim
->
dstT
();
return
RET_INPUT_TENSOR_ERROR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/concat.cc
浏览文件 @
78c91228
...
...
@@ -70,7 +70,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output_shape
[
axis
]
=
output_axis_dim
;
outputs_
[
0
]
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/crop.cc
浏览文件 @
78c91228
...
...
@@ -32,7 +32,8 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return
RET_PARAM_INVALID
;
}
outputs
[
0
]
->
set_shape
(
inputs
[
1
]
->
shape
());
outputs
[
0
]
->
SetFormat
(
inputs
[
1
]
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/depth_to_space.cc
浏览文件 @
78c91228
...
...
@@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace
{
constexpr
int
kDepthToSpaceOutputNum
=
1
;
constexpr
int
kDepthToSpaceInputNum
=
1
;
}
}
// namespace
int
DepthToSpace
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs
,
std
::
vector
<
tensor
::
Tensor
*>
outputs
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
...
...
@@ -56,7 +56,8 @@ int DepthToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape
[
kNHWC_c_index
]
=
input_shape
[
kNHWC_c_index
]
/
(
block_size
*
block_size
);
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/expand_dims.cc
浏览文件 @
78c91228
...
...
@@ -45,7 +45,8 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
out_shape
.
insert
(
out_shape
.
begin
()
+
dim
,
1
,
1
);
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/fill.cc
浏览文件 @
78c91228
...
...
@@ -42,7 +42,8 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
(
void
)
output_shape
.
insert
(
output_shape
.
begin
(),
fill_prim
->
dims
()
->
begin
(),
fill_prim
->
dims
()
->
end
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/flatten.cc
浏览文件 @
78c91228
...
...
@@ -43,7 +43,8 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/fullconnection.cc
浏览文件 @
78c91228
...
...
@@ -56,7 +56,8 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
out_shape
[
fc_prim
->
axis
()]
=
input1
->
shape
()[
0
];
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/gather.cc
浏览文件 @
78c91228
...
...
@@ -71,7 +71,8 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/gather_nd.cc
浏览文件 @
78c91228
...
...
@@ -59,7 +59,8 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/matmul.cc
浏览文件 @
78c91228
...
...
@@ -57,7 +57,8 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
y_shape
[
y_shape_size
-
1
]
=
w_shape
[
w_shape
.
size
()
-
1
];
output
->
set_shape
(
y_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/one_hot.cc
浏览文件 @
78c91228
...
...
@@ -67,6 +67,8 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
return
RET_NULL_PTR
;
}
output
->
set_data_type
(
on_value
->
data_type
());
output
->
SetFormat
(
on_value
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/ops.cc
浏览文件 @
78c91228
...
...
@@ -138,7 +138,8 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/pad.cc
浏览文件 @
78c91228
...
...
@@ -55,9 +55,9 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/pooling.cc
浏览文件 @
78c91228
...
...
@@ -74,6 +74,7 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape
.
at
(
2
)
=
output_w
;
output
->
set_shape
(
input_shape
);
output
->
set_data_type
(
input
->
data_type
());
// todo: temp fix
output
->
SetFormat
(
schema
::
Format_NHWC
);
return
RET_OK
;
...
...
mindspore/lite/src/ops/range.cc
浏览文件 @
78c91228
...
...
@@ -34,7 +34,8 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
in_shape
.
push_back
(
shape_size
);
output
->
set_shape
(
in_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/rank.cc
浏览文件 @
78c91228
...
...
@@ -29,7 +29,8 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std
::
vector
<
int
>
in_shape
(
1
,
1
);
output
->
set_shape
(
in_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/reduce.cc
浏览文件 @
78c91228
...
...
@@ -73,6 +73,8 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/reshape.cc
浏览文件 @
78c91228
...
...
@@ -114,7 +114,8 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/resize.cc
浏览文件 @
78c91228
...
...
@@ -45,7 +45,8 @@ int Resize::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output_shape
.
push_back
(
input
->
Channel
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/scatter_nd.cc
浏览文件 @
78c91228
...
...
@@ -57,7 +57,8 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
std
::
vector
<
int
>
out_shape
(
shape_data
,
shape_data
+
sizeof
(
shape_data
)
/
sizeof
(
shape_data
[
0
]));
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
update
->
data_type
());
output
->
SetFormat
(
update
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/slice.cc
浏览文件 @
78c91228
...
...
@@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace
{
constexpr
int
kSliceInputNum
=
1
;
constexpr
int
kSliceOutputNum
=
1
;
}
}
// namespace
int
Slice
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs
,
std
::
vector
<
tensor
::
Tensor
*>
outputs
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
...
...
@@ -47,13 +47,13 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return
RET_PARAM_INVALID
;
}
if
(
input_shape
[
i
]
<=
slice_begin
[
i
])
{
MS_LOG
(
ERROR
)
<<
"Invalid begin input!begin["
<<
i
<<
"]="
<<
slice_begin
[
i
]
<<
" which should be <= "
<<
input_shape
[
i
];
MS_LOG
(
ERROR
)
<<
"Invalid begin input!begin["
<<
i
<<
"]="
<<
slice_begin
[
i
]
<<
" which should be <= "
<<
input_shape
[
i
];
return
RET_PARAM_INVALID
;
}
if
(
slice_size
[
i
]
>
(
input_shape
[
i
]
-
slice_begin
[
i
]))
{
MS_LOG
(
ERROR
)
<<
"Invalid size input "
<<
slice_size
[
i
]
<<
" which should be <= "
<<
input_shape
[
i
]
-
slice_begin
[
i
];
MS_LOG
(
ERROR
)
<<
"Invalid size input "
<<
slice_size
[
i
]
<<
" which should be <= "
<<
input_shape
[
i
]
-
slice_begin
[
i
];
return
RET_PARAM_INVALID
;
}
...
...
@@ -62,7 +62,8 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/softmax.cc
浏览文件 @
78c91228
...
...
@@ -28,7 +28,8 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT
(
output
!=
nullptr
);
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/split.cc
浏览文件 @
78c91228
...
...
@@ -55,8 +55,8 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
output_shape
[
split_dim
]
=
split_dim_i
;
outputs_
[
i
]
->
set_shape
(
output_shape
);
outputs_
[
i
]
->
set_data_type
(
input
->
data_type
());
outputs_
[
i
]
->
SetFormat
(
input
->
GetFormat
());
}
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/squeeze.cc
浏览文件 @
78c91228
...
...
@@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace
{
constexpr
int
kSqueezeInputNum
=
1
;
constexpr
int
kSqueezeOutputNum
=
1
;
}
}
// namespace
int
Squeeze
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
if
(
kSqueezeInputNum
!=
inputs_
.
size
())
{
...
...
@@ -45,31 +45,31 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
std
::
vector
<
int
>
axes_
;
for
(
auto
iter
=
axis
->
begin
();
iter
!=
axis
->
end
();
iter
++
)
{
axes_
.
push_back
(
*
iter
);
}
}
if
(
axes_
.
size
()
==
0
)
{
for
(
int
i
=
0
;
i
<
in_shape
.
size
();
i
++
)
{
if
(
in_shape
[
i
]
!=
1
)
{
out_shape
.
push_back
(
in_shape
[
i
]);
}
if
(
in_shape
[
i
]
!=
1
)
{
out_shape
.
push_back
(
in_shape
[
i
]);
}
}
}
else
{
int
axisIdx
=
0
;
for
(
int
i
=
0
;
i
<
in_shape
.
size
();
i
++
)
{
if
(
axisIdx
<
axes_
.
size
()
&&
axes_
[
axisIdx
]
==
i
)
{
MS_ASSERT
(
in_shape
[
i
]
==
1
);
axisIdx
++
;
continue
;
}
else
{
out_shape
.
push_back
(
in_shape
[
i
]);
}
int
axisIdx
=
0
;
for
(
int
i
=
0
;
i
<
in_shape
.
size
();
i
++
)
{
if
(
axisIdx
<
axes_
.
size
()
&&
axes_
[
axisIdx
]
==
i
)
{
MS_ASSERT
(
in_shape
[
i
]
==
1
);
axisIdx
++
;
continue
;
}
else
{
out_shape
.
push_back
(
in_shape
[
i
]);
}
}
}
outputs_
.
front
()
->
set_shape
(
out_shape
);
outputs_
.
front
()
->
set_data_type
(
in_tensor
->
data_type
());
outputs_
.
front
()
->
SetFormat
(
in_tensor
->
GetFormat
());
return
0
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/stack.cc
浏览文件 @
78c91228
...
...
@@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace
{
constexpr
int
kStackOutputNum
=
1
;
constexpr
int
kStackMinInputNum
=
2
;
}
}
// namespace
int
Stack
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs
,
std
::
vector
<
tensor
::
Tensor
*>
outputs
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
...
...
@@ -61,7 +61,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
output_shape
.
insert
(
output_shape
.
begin
()
+
axis
,
inputs
.
size
());
outputs
[
0
]
->
set_shape
(
output_shape
);
outputs
[
0
]
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/strided_slice.cc
浏览文件 @
78c91228
...
...
@@ -157,6 +157,8 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
outputs
.
front
()
->
set_shape
(
output_shape
);
outputs
.
front
()
->
set_data_type
(
input
->
data_type
());
outputs
[
0
]
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/tile.cc
浏览文件 @
78c91228
...
...
@@ -37,9 +37,9 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
out_shape
.
push_back
(
tmp
);
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/topk.cc
浏览文件 @
78c91228
...
...
@@ -37,12 +37,13 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
output0
->
set_shape
(
input
->
shape
());
output0
->
set_data_type
(
input
->
data_type
());
// output0->shape().back() = topk_prim->k();
// output0->shape().back() = topk_prim->k();
output1
->
set_shape
(
input
->
shape
());
output1
->
set_data_type
(
input
->
data_type
());
// output1->shape().back() = topk_prim->k();
// output1->shape().back() = topk_prim->k();
output1
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/transpose.cc
浏览文件 @
78c91228
...
...
@@ -47,7 +47,8 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/unique.cc
浏览文件 @
78c91228
...
...
@@ -36,7 +36,9 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output0
->
set_data_type
(
input
->
data_type
());
output1
->
set_shape
(
input
->
shape
());
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
output0
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/unsqueeze.cc
浏览文件 @
78c91228
...
...
@@ -65,9 +65,9 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
}
output
->
SetFormat
(
input
->
GetFormat
());
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/unstack.cc
浏览文件 @
78c91228
...
...
@@ -41,8 +41,8 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
MS_ASSERT
(
out
!=
nullptr
);
out
->
set_shape
(
output_shape
);
out
->
set_data_type
(
input
->
data_type
());
out
->
SetFormat
(
input
->
GetFormat
());
}
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/where.cc
浏览文件 @
78c91228
...
...
@@ -73,7 +73,8 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
output_shape
[
axisout
]
=
nummax
;
outputs_
[
0
]
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/zeroslike.cc
浏览文件 @
78c91228
...
...
@@ -33,7 +33,8 @@ int ZerosLike::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
output
->
set_shape
(
input
->
shape
());
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/tools/benchmark/benchmark.cc
浏览文件 @
78c91228
...
...
@@ -330,6 +330,7 @@ int Benchmark::MarkAccuracy() {
}
ReadCalibData
();
CompareOutput
();
if
(
cleanData
)
{
for
(
auto
&
msOutput
:
msOutputs
)
{
for
(
auto
&
outputTensor
:
msOutput
.
second
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录