Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
80d570f0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
80d570f0
编写于
8月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4999 change long to int64
Merge pull request !4999 from yeyunpeng2020/primitive
上级
dde25759
6b46acb3
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
119 addition
and
72 deletion
+119
-72
mindspore/lite/src/ops/crop.cc
mindspore/lite/src/ops/crop.cc
+9
-9
mindspore/lite/src/ops/crop.h
mindspore/lite/src/ops/crop.h
+4
-4
mindspore/lite/src/ops/detection_post_process.cc
mindspore/lite/src/ops/detection_post_process.cc
+16
-16
mindspore/lite/src/ops/detection_post_process.h
mindspore/lite/src/ops/detection_post_process.h
+8
-8
mindspore/lite/src/ops/permute.cc
mindspore/lite/src/ops/permute.cc
+5
-5
mindspore/lite/src/ops/permute.h
mindspore/lite/src/ops/permute.h
+2
-2
mindspore/lite/src/ops/primitive_c.cc
mindspore/lite/src/ops/primitive_c.cc
+56
-9
mindspore/lite/src/ops/reshape.cc
mindspore/lite/src/ops/reshape.cc
+5
-5
mindspore/lite/src/ops/reshape.h
mindspore/lite/src/ops/reshape.h
+2
-2
mindspore/lite/src/ops/resize.cc
mindspore/lite/src/ops/resize.cc
+8
-8
mindspore/lite/src/ops/resize.h
mindspore/lite/src/ops/resize.h
+4
-4
未找到文件。
mindspore/lite/src/ops/crop.cc
浏览文件 @
80d570f0
...
...
@@ -19,22 +19,22 @@
namespace
mindspore
{
namespace
lite
{
#ifdef PRIMITIVE_WRITEABLE
long
Crop
::
GetAxis
()
const
{
return
this
->
primitive_
->
value
.
AsCrop
()
->
axis
;
}
std
::
vector
<
long
>
Crop
::
GetOffsets
()
const
{
return
this
->
primitive_
->
value
.
AsCrop
()
->
offsets
;
}
int64_t
Crop
::
GetAxis
()
const
{
return
this
->
primitive_
->
value
.
AsCrop
()
->
axis
;
}
std
::
vector
<
int64_t
>
Crop
::
GetOffsets
()
const
{
return
this
->
primitive_
->
value
.
AsCrop
()
->
offsets
;
}
void
Crop
::
SetAxis
(
long
axis
)
{
this
->
primitive_
->
value
.
AsCrop
()
->
axis
=
axis
;
}
void
Crop
::
SetOffsets
(
const
std
::
vector
<
long
>
&
offsets
)
{
this
->
primitive_
->
value
.
AsCrop
()
->
offsets
=
offsets
;
}
void
Crop
::
SetAxis
(
int64_t
axis
)
{
this
->
primitive_
->
value
.
AsCrop
()
->
axis
=
axis
;
}
void
Crop
::
SetOffsets
(
const
std
::
vector
<
int64_t
>
&
offsets
)
{
this
->
primitive_
->
value
.
AsCrop
()
->
offsets
=
offsets
;
}
#else
long
Crop
::
GetAxis
()
const
{
return
this
->
primitive_
->
value_as_Crop
()
->
axis
();
}
std
::
vector
<
long
>
Crop
::
GetOffsets
()
const
{
int64_t
Crop
::
GetAxis
()
const
{
return
this
->
primitive_
->
value_as_Crop
()
->
axis
();
}
std
::
vector
<
int64_t
>
Crop
::
GetOffsets
()
const
{
auto
fb_vector
=
this
->
primitive_
->
value_as_Crop
()
->
offsets
();
return
std
::
vector
<
long
>
(
fb_vector
->
begin
(),
fb_vector
->
end
());
return
std
::
vector
<
int64_t
>
(
fb_vector
->
begin
(),
fb_vector
->
end
());
}
void
Crop
::
SetAxis
(
long
axis
)
{}
void
Crop
::
SetOffsets
(
const
std
::
vector
<
long
>
&
offsets
)
{}
void
Crop
::
SetAxis
(
int64_t
axis
)
{}
void
Crop
::
SetOffsets
(
const
std
::
vector
<
int64_t
>
&
offsets
)
{}
#endif
namespace
{
constexpr
int
kCropOutputNum
=
1
;
...
...
mindspore/lite/src/ops/crop.h
浏览文件 @
80d570f0
...
...
@@ -34,10 +34,10 @@ class Crop : public PrimitiveC {
explicit
Crop
(
schema
::
Primitive
*
primitive
)
:
PrimitiveC
(
primitive
)
{}
#endif
int
InferShape
(
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
)
override
;
long
GetAxis
()
const
;
std
::
vector
<
long
>
GetOffsets
()
const
;
void
SetAxis
(
long
axis
);
void
SetOffsets
(
const
std
::
vector
<
long
>
&
offsets
);
int64_t
GetAxis
()
const
;
std
::
vector
<
int64_t
>
GetOffsets
()
const
;
void
SetAxis
(
int64_t
axis
);
void
SetOffsets
(
const
std
::
vector
<
int64_t
>
&
offsets
);
};
}
// namespace lite
}
// namespace mindspore
...
...
mindspore/lite/src/ops/detection_post_process.cc
浏览文件 @
80d570f0
...
...
@@ -31,16 +31,16 @@ float DetectionPostProcess::GetNmsIouThreshold() const {
float
DetectionPostProcess
::
GetNmsScoreThreshold
()
const
{
return
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
NmsScoreThreshold
;
}
long
DetectionPostProcess
::
GetMaxDetections
()
const
{
int64_t
DetectionPostProcess
::
GetMaxDetections
()
const
{
return
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
MaxDetections
;
}
long
DetectionPostProcess
::
GetDetectionsPreClass
()
const
{
int64_t
DetectionPostProcess
::
GetDetectionsPreClass
()
const
{
return
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
DetectionsPreClass
;
}
long
DetectionPostProcess
::
GetMaxClassesPreDetection
()
const
{
int64_t
DetectionPostProcess
::
GetMaxClassesPreDetection
()
const
{
return
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
MaxClassesPreDetection
;
}
long
DetectionPostProcess
::
GetNumClasses
()
const
{
int64_t
DetectionPostProcess
::
GetNumClasses
()
const
{
return
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
NumClasses
;
}
bool
DetectionPostProcess
::
GetUseRegularNms
()
const
{
...
...
@@ -71,16 +71,16 @@ void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {
void
DetectionPostProcess
::
SetNmsScoreThreshold
(
float
nms_score_threshold
)
{
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
NmsScoreThreshold
=
nms_score_threshold
;
}
void
DetectionPostProcess
::
SetMaxDetections
(
long
max_detections
)
{
void
DetectionPostProcess
::
SetMaxDetections
(
int64_t
max_detections
)
{
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
MaxClassesPreDetection
=
max_detections
;
}
void
DetectionPostProcess
::
SetDetectionsPreClass
(
long
detections_pre_class
)
{
void
DetectionPostProcess
::
SetDetectionsPreClass
(
int64_t
detections_pre_class
)
{
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
DetectionsPreClass
=
detections_pre_class
;
}
void
DetectionPostProcess
::
SetMaxClassesPreDetection
(
long
max_classes_pre_detection
)
{
void
DetectionPostProcess
::
SetMaxClassesPreDetection
(
int64_t
max_classes_pre_detection
)
{
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
MaxClassesPreDetection
=
max_classes_pre_detection
;
}
void
DetectionPostProcess
::
SetNumClasses
(
long
num_classes
)
{
void
DetectionPostProcess
::
SetNumClasses
(
int64_t
num_classes
)
{
this
->
primitive_
->
value
.
AsDetectionPostProcess
()
->
NumClasses
=
num_classes
;
}
void
DetectionPostProcess
::
SetUseRegularNms
(
bool
use_regular_nms
)
{
...
...
@@ -103,16 +103,16 @@ float DetectionPostProcess::GetNmsIouThreshold() const {
float
DetectionPostProcess
::
GetNmsScoreThreshold
()
const
{
return
this
->
primitive_
->
value_as_DetectionPostProcess
()
->
NmsScoreThreshold
();
}
long
DetectionPostProcess
::
GetMaxDetections
()
const
{
int64_t
DetectionPostProcess
::
GetMaxDetections
()
const
{
return
this
->
primitive_
->
value_as_DetectionPostProcess
()
->
MaxDetections
();
}
long
DetectionPostProcess
::
GetDetectionsPreClass
()
const
{
int64_t
DetectionPostProcess
::
GetDetectionsPreClass
()
const
{
return
this
->
primitive_
->
value_as_DetectionPostProcess
()
->
DetectionsPreClass
();
}
long
DetectionPostProcess
::
GetMaxClassesPreDetection
()
const
{
int64_t
DetectionPostProcess
::
GetMaxClassesPreDetection
()
const
{
return
this
->
primitive_
->
value_as_DetectionPostProcess
()
->
MaxClassesPreDetection
();
}
long
DetectionPostProcess
::
GetNumClasses
()
const
{
int64_t
DetectionPostProcess
::
GetNumClasses
()
const
{
return
this
->
primitive_
->
value_as_DetectionPostProcess
()
->
NumClasses
();
}
bool
DetectionPostProcess
::
GetUseRegularNms
()
const
{
...
...
@@ -127,10 +127,10 @@ void DetectionPostProcess::SetXScale(float x_scale) {}
void
DetectionPostProcess
::
SetYScale
(
float
y_scale
)
{}
void
DetectionPostProcess
::
SetNmsIouThreshold
(
float
nms_iou_threshold
)
{}
void
DetectionPostProcess
::
SetNmsScoreThreshold
(
float
nms_score_threshold
)
{}
void
DetectionPostProcess
::
SetMaxDetections
(
long
max_detections
)
{}
void
DetectionPostProcess
::
SetDetectionsPreClass
(
long
detections_pre_class
)
{}
void
DetectionPostProcess
::
SetMaxClassesPreDetection
(
long
max_classes_pre_detection
)
{}
void
DetectionPostProcess
::
SetNumClasses
(
long
num_classes
)
{}
void
DetectionPostProcess
::
SetMaxDetections
(
int64_t
max_detections
)
{}
void
DetectionPostProcess
::
SetDetectionsPreClass
(
int64_t
detections_pre_class
)
{}
void
DetectionPostProcess
::
SetMaxClassesPreDetection
(
int64_t
max_classes_pre_detection
)
{}
void
DetectionPostProcess
::
SetNumClasses
(
int64_t
num_classes
)
{}
void
DetectionPostProcess
::
SetUseRegularNms
(
bool
use_regular_nms
)
{}
#endif
}
// namespace lite
...
...
mindspore/lite/src/ops/detection_post_process.h
浏览文件 @
80d570f0
...
...
@@ -41,10 +41,10 @@ class DetectionPostProcess : public PrimitiveC {
float
GetYScale
()
const
;
float
GetNmsIouThreshold
()
const
;
float
GetNmsScoreThreshold
()
const
;
long
GetMaxDetections
()
const
;
long
GetDetectionsPreClass
()
const
;
long
GetMaxClassesPreDetection
()
const
;
long
GetNumClasses
()
const
;
int64_t
GetMaxDetections
()
const
;
int64_t
GetDetectionsPreClass
()
const
;
int64_t
GetMaxClassesPreDetection
()
const
;
int64_t
GetNumClasses
()
const
;
bool
GetUseRegularNms
()
const
;
void
SetFormat
(
int
format
);
void
SetInputSize
(
int
input_size
);
...
...
@@ -54,10 +54,10 @@ class DetectionPostProcess : public PrimitiveC {
void
SetYScale
(
float
y_scale
);
void
SetNmsIouThreshold
(
float
nms_iou_threshold
);
void
SetNmsScoreThreshold
(
float
nms_score_threshold
);
void
SetMaxDetections
(
long
max_detections
);
void
SetDetectionsPreClass
(
long
detections_pre_class
);
void
SetMaxClassesPreDetection
(
long
max_classes_pre_detection
);
void
SetNumClasses
(
long
num_classes
);
void
SetMaxDetections
(
int64_t
max_detections
);
void
SetDetectionsPreClass
(
int64_t
detections_pre_class
);
void
SetMaxClassesPreDetection
(
int64_t
max_classes_pre_detection
);
void
SetNumClasses
(
int64_t
num_classes
);
void
SetUseRegularNms
(
bool
use_regular_nms
);
};
}
// namespace lite
...
...
mindspore/lite/src/ops/permute.cc
浏览文件 @
80d570f0
...
...
@@ -19,18 +19,18 @@
namespace
mindspore
{
namespace
lite
{
#ifdef PRIMITIVE_WRITEABLE
std
::
vector
<
long
>
Permute
::
GetOrder
()
const
{
return
this
->
primitive_
->
value
.
AsPermute
()
->
order
;
}
std
::
vector
<
int64_t
>
Permute
::
GetOrder
()
const
{
return
this
->
primitive_
->
value
.
AsPermute
()
->
order
;
}
void
Permute
::
SetOrder
(
const
std
::
vector
<
long
>
&
order
)
{
this
->
primitive_
->
value
.
AsPermute
()
->
order
=
order
;
}
void
Permute
::
SetOrder
(
const
std
::
vector
<
int64_t
>
&
order
)
{
this
->
primitive_
->
value
.
AsPermute
()
->
order
=
order
;
}
#else
std
::
vector
<
long
>
Permute
::
GetOrder
()
const
{
std
::
vector
<
int64_t
>
Permute
::
GetOrder
()
const
{
auto
fb_vector
=
this
->
primitive_
->
value_as_Permute
()
->
order
();
return
std
::
vector
<
long
>
(
fb_vector
->
begin
(),
fb_vector
->
end
());
return
std
::
vector
<
int64_t
>
(
fb_vector
->
begin
(),
fb_vector
->
end
());
}
void
Permute
::
SetOrder
(
const
std
::
vector
<
long
>
&
order
)
{}
void
Permute
::
SetOrder
(
const
std
::
vector
<
int64_t
>
&
order
)
{}
#endif
}
// namespace lite
}
// namespace mindspore
mindspore/lite/src/ops/permute.h
浏览文件 @
80d570f0
...
...
@@ -33,8 +33,8 @@ class Permute : public PrimitiveC {
#else
explicit
Permute
(
schema
::
Primitive
*
primitive
)
:
PrimitiveC
(
primitive
)
{}
#endif
std
::
vector
<
long
>
GetOrder
()
const
;
void
SetOrder
(
const
std
::
vector
<
long
>
&
order
);
std
::
vector
<
int64_t
>
GetOrder
()
const
;
void
SetOrder
(
const
std
::
vector
<
int64_t
>
&
order
);
};
}
// namespace lite
}
// namespace mindspore
...
...
mindspore/lite/src/ops/primitive_c.cc
浏览文件 @
80d570f0
...
...
@@ -410,11 +410,32 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return
new
Shape
(
primitive
);
case
schema
::
PrimitiveType_Unsqueeze
:
return
new
Unsqueeze
(
primitive
);
case
schema
::
PrimitiveType_BatchToSpace
:
return
new
BatchToSpace
(
primitive
);
case
schema
::
PrimitiveType_SpaceToBatch
:
return
new
SpaceToBatch
(
primitive
);
case
schema
::
PrimitiveType_BroadcastTo
:
return
new
BroadcastTo
(
primitive
);
case
schema
::
PrimitiveType_DepthToSpace
:
return
new
DepthToSpace
(
primitive
);
case
schema
::
PrimitiveType_Lstm
:
return
new
Lstm
(
primitive
);
case
schema
::
PrimitiveType_ZerosLike
:
return
new
ZerosLike
(
primitive
);
case
schema
::
PrimitiveType_MakeTuple
:
return
new
MakeTuple
(
primitive
);
case
schema
::
PrimitiveType_Where
:
return
new
Where
(
primitive
);
case
schema
::
PrimitiveType_ScatterND
:
return
new
ScatterND
(
primitive
);
case
schema
::
PrimitiveType_ConstantOfShape
:
return
new
ConstantOfShape
(
primitive
);
default:
MS_LOG
(
ERROR
)
<<
"Unsupported primitive type in UnPackFromSchemaPrimitiveT : "
<<
schema
::
EnumNamePrimitiveType
(
op_type
);
return
nullptr
;
break
;
}
return
nullptr
;
}
#else
PrimitiveC
*
PrimitiveC
::
UnPackFromSchemaPrimitive
(
mindspore
::
schema
::
Primitive
*
primitive
)
{
...
...
@@ -433,6 +454,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return
new
Reduce
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Pooling
:
return
new
Pooling
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_ROIPooling
:
return
new
ROIPooling
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_DepthwiseConv2D
:
return
new
DepthwiseConv2D
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_FusedBatchNorm
:
...
...
@@ -443,6 +466,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return
new
FullConnection
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Power
:
return
new
Power
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Pad
:
return
new
Pad
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Range
:
return
new
Range
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Mul
:
...
...
@@ -469,20 +494,22 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return
new
Scale
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Eltwise
:
return
new
Eltwise
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Ceil
:
return
new
Ceil
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Concat
:
return
new
Concat
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Fill
:
return
new
Fill
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Nhwc2Nchw
:
return
new
Nhwc2Nchw
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Nchw2Nhwc
:
return
new
Nchw2Nhwc
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Transpose
:
return
new
Transpose
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Slice
:
return
new
Slice
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Squeeze
:
return
new
Squeeze
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Nchw2Nhwc
:
return
new
Nchw2Nhwc
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Nhwc2Nchw
:
return
new
Nhwc2Nchw
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Flatten
:
return
new
Flatten
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Mean
:
...
...
@@ -521,8 +548,6 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return
new
Maximum
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Minimum
:
return
new
Minimum
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Pad
:
return
new
Pad
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_StridedSlice
:
return
new
StridedSlice
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Prelu
:
...
...
@@ -559,12 +584,12 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return
new
GreaterEqual
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Floor
:
return
new
Floor
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Ceil
:
return
new
Ceil
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Split
:
return
new
Split
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_OneHot
:
return
new
OneHot
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_PriorBox
:
return
new
PriorBox
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_SpaceToDepth
:
return
new
SpaceToDepth
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Tile
:
...
...
@@ -591,7 +616,29 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return
new
Shape
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Unsqueeze
:
return
new
Unsqueeze
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_BatchToSpace
:
return
new
BatchToSpace
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_SpaceToBatch
:
return
new
SpaceToBatch
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_BroadcastTo
:
return
new
BroadcastTo
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_DepthToSpace
:
return
new
DepthToSpace
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Lstm
:
return
new
Lstm
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_ZerosLike
:
return
new
ZerosLike
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_MakeTuple
:
return
new
MakeTuple
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Where
:
return
new
Where
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_ScatterND
:
return
new
ScatterND
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_ConstantOfShape
:
return
new
ConstantOfShape
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
default:
MS_LOG
(
ERROR
)
<<
"Unsupported primitive type in UnPackFromSchemaPrimitive : "
<<
schema
::
EnumNamePrimitiveType
(
op_type
);
break
;
}
return
nullptr
;
...
...
mindspore/lite/src/ops/reshape.cc
浏览文件 @
80d570f0
...
...
@@ -25,10 +25,10 @@ namespace mindspore {
namespace
lite
{
#ifdef PRIMITIVE_WRITEABLE
int
Reshape
::
GetFormat
()
const
{
return
this
->
primitive_
->
value
.
AsReshape
()
->
format
;
}
std
::
vector
<
long
>
Reshape
::
GetShape
()
const
{
return
this
->
primitive_
->
value
.
AsReshape
()
->
shape
;
}
std
::
vector
<
int64_t
>
Reshape
::
GetShape
()
const
{
return
this
->
primitive_
->
value
.
AsReshape
()
->
shape
;
}
void
Reshape
::
SetFormat
(
int
format
)
{
this
->
primitive_
->
value
.
AsReshape
()
->
format
=
(
schema
::
Format
)
format
;
}
void
Reshape
::
SetShape
(
const
std
::
vector
<
long
>
&
shape
)
{
this
->
primitive_
->
value
.
AsReshape
()
->
shape
=
shape
;
}
void
Reshape
::
SetShape
(
const
std
::
vector
<
int64_t
>
&
shape
)
{
this
->
primitive_
->
value
.
AsReshape
()
->
shape
=
shape
;
}
int
Reshape
::
UnPackAttr
(
const
Primitive
&
prim
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
{
this
->
primitive_
=
new
(
schema
::
PrimitiveT
);
auto
attr
=
std
::
make_unique
<
schema
::
ReshapeT
>
();
...
...
@@ -59,13 +59,13 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
#else
int
Reshape
::
GetFormat
()
const
{
return
this
->
primitive_
->
value_as_Reshape
()
->
format
();
}
std
::
vector
<
long
>
Reshape
::
GetShape
()
const
{
std
::
vector
<
int64_t
>
Reshape
::
GetShape
()
const
{
auto
fb_vector
=
this
->
primitive_
->
value_as_Reshape
()
->
shape
();
return
std
::
vector
<
long
>
(
fb_vector
->
begin
(),
fb_vector
->
end
());
return
std
::
vector
<
int64_t
>
(
fb_vector
->
begin
(),
fb_vector
->
end
());
}
void
Reshape
::
SetFormat
(
int
format
)
{}
void
Reshape
::
SetShape
(
const
std
::
vector
<
long
>
&
shape
)
{}
void
Reshape
::
SetShape
(
const
std
::
vector
<
int64_t
>
&
shape
)
{}
#endif
int
Reshape
::
CalNewShape
(
const
tensor
::
Tensor
*
in_tensor
,
std
::
vector
<
int
>
*
out_shape
)
const
{
...
...
mindspore/lite/src/ops/reshape.h
浏览文件 @
80d570f0
...
...
@@ -36,9 +36,9 @@ class Reshape : public PrimitiveC {
#endif
int
InferShape
(
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
)
override
;
int
GetFormat
()
const
;
std
::
vector
<
long
>
GetShape
()
const
;
std
::
vector
<
int64_t
>
GetShape
()
const
;
void
SetFormat
(
int
format
);
void
SetShape
(
const
std
::
vector
<
long
>
&
shape
);
void
SetShape
(
const
std
::
vector
<
int64_t
>
&
shape
);
private:
int
CalNewShape
(
const
lite
::
tensor
::
Tensor
*
in_tensor
,
std
::
vector
<
int
>
*
out_shape
)
const
;
...
...
mindspore/lite/src/ops/resize.cc
浏览文件 @
80d570f0
...
...
@@ -21,15 +21,15 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int
Resize
::
GetFormat
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
format
;
}
int
Resize
::
GetMethod
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
method
;
}
long
Resize
::
GetNewHeight
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
newHeight
;
}
long
Resize
::
GetNewWidth
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
newWidth
;
}
int64_t
Resize
::
GetNewHeight
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
newHeight
;
}
int64_t
Resize
::
GetNewWidth
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
newWidth
;
}
bool
Resize
::
GetAlignCorners
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
alignCorners
;
}
bool
Resize
::
GetPreserveAspectRatio
()
const
{
return
this
->
primitive_
->
value
.
AsResize
()
->
preserveAspectRatio
;
}
void
Resize
::
SetFormat
(
int
format
)
{
this
->
primitive_
->
value
.
AsResize
()
->
format
=
(
schema
::
Format
)
format
;
}
void
Resize
::
SetMethod
(
int
method
)
{
this
->
primitive_
->
value
.
AsResize
()
->
method
=
(
schema
::
ResizeMethod
)
method
;
}
void
Resize
::
SetNewHeight
(
long
new_height
)
{
this
->
primitive_
->
value
.
AsResize
()
->
newHeight
=
new_height
;
}
void
Resize
::
SetNewWidth
(
long
new_width
)
{
this
->
primitive_
->
value
.
AsResize
()
->
newWidth
=
new_width
;
}
void
Resize
::
SetNewHeight
(
int64_t
new_height
)
{
this
->
primitive_
->
value
.
AsResize
()
->
newHeight
=
new_height
;
}
void
Resize
::
SetNewWidth
(
int64_t
new_width
)
{
this
->
primitive_
->
value
.
AsResize
()
->
newWidth
=
new_width
;
}
void
Resize
::
SetAlignCorners
(
bool
align_corners
)
{
this
->
primitive_
->
value
.
AsResize
()
->
alignCorners
=
align_corners
;
}
void
Resize
::
SetPreserveAspectRatio
(
bool
preserve_aspect_ratio
)
{
this
->
primitive_
->
value
.
AsResize
()
->
preserveAspectRatio
=
preserve_aspect_ratio
;
...
...
@@ -39,15 +39,15 @@ void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {
int
Resize
::
GetFormat
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
format
();
}
int
Resize
::
GetMethod
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
method
();
}
long
Resize
::
GetNewHeight
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
newHeight
();
}
long
Resize
::
GetNewWidth
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
newWidth
();
}
int64_t
Resize
::
GetNewHeight
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
newHeight
();
}
int64_t
Resize
::
GetNewWidth
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
newWidth
();
}
bool
Resize
::
GetAlignCorners
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
alignCorners
();
}
bool
Resize
::
GetPreserveAspectRatio
()
const
{
return
this
->
primitive_
->
value_as_Resize
()
->
preserveAspectRatio
();
}
void
Resize
::
SetFormat
(
int
format
)
{}
void
Resize
::
SetMethod
(
int
method
)
{}
void
Resize
::
SetNewHeight
(
long
new_height
)
{}
void
Resize
::
SetNewWidth
(
long
new_width
)
{}
void
Resize
::
SetNewHeight
(
int64_t
new_height
)
{}
void
Resize
::
SetNewWidth
(
int64_t
new_width
)
{}
void
Resize
::
SetAlignCorners
(
bool
align_corners
)
{}
void
Resize
::
SetPreserveAspectRatio
(
bool
preserve_aspect_ratio
)
{}
#endif
...
...
mindspore/lite/src/ops/resize.h
浏览文件 @
80d570f0
...
...
@@ -36,14 +36,14 @@ class Resize : public PrimitiveC {
int
InferShape
(
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
)
override
;
int
GetFormat
()
const
;
int
GetMethod
()
const
;
long
GetNewHeight
()
const
;
long
GetNewWidth
()
const
;
int64_t
GetNewHeight
()
const
;
int64_t
GetNewWidth
()
const
;
bool
GetAlignCorners
()
const
;
bool
GetPreserveAspectRatio
()
const
;
void
SetFormat
(
int
format
);
void
SetMethod
(
int
method
);
void
SetNewHeight
(
long
new_height
);
void
SetNewWidth
(
long
new_width
);
void
SetNewHeight
(
int64_t
new_height
);
void
SetNewWidth
(
int64_t
new_width
);
void
SetAlignCorners
(
bool
align_corners
);
void
SetPreserveAspectRatio
(
bool
preserve_aspect_ratio
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录