Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
慢慢CG
Mace
提交
cbd381b0
Mace
项目概览
慢慢CG
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cbd381b0
编写于
3月 18, 2019
作者:
李
李寅
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' into 'master'
Optimize reshape op See merge request !1018
上级
cdf5d030
56a81579
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
13 addition
and
19 deletion
+13
-19
mace/ops/reshape.cc
mace/ops/reshape.cc
+13
-19
未找到文件。
mace/ops/reshape.cc
浏览文件 @
cbd381b0
...
@@ -23,16 +23,12 @@ template <DeviceType D, class T>
...
@@ -23,16 +23,12 @@ template <DeviceType D, class T>
class
ReshapeOp
:
public
Operation
{
class
ReshapeOp
:
public
Operation
{
public:
public:
explicit
ReshapeOp
(
OpConstructContext
*
context
)
explicit
ReshapeOp
(
OpConstructContext
*
context
)
:
Operation
(
context
)
{}
:
Operation
(
context
),
has_df_
(
Operation
::
GetOptionalArg
<
int
>
(
"has_data_format"
,
0
))
{}
MaceStatus
Run
(
OpContext
*
context
)
override
{
MaceStatus
Run
(
OpContext
*
context
)
override
{
MACE_UNUSED
(
context
);
MACE_UNUSED
(
context
);
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
const
std
::
vector
<
index_t
>
&
input_shape
=
input
->
shape
();
int
axis
=
Operation
::
GetOptionalArg
<
int
>
(
"reshape_axis"
,
0
);
int
num_axes
=
Operation
::
GetOptionalArg
<
int
>
(
"num_axes"
,
-
1
);
MACE_CHECK
(
axis
==
0
&&
num_axes
==
-
1
,
"Only support axis = 0 and num_axes = -1"
);
const
Tensor
*
shape
=
this
->
Input
(
SHAPE
);
const
Tensor
*
shape
=
this
->
Input
(
SHAPE
);
const
index_t
num_dims
=
shape
->
dim_size
()
==
0
?
0
:
shape
->
dim
(
0
);
const
index_t
num_dims
=
shape
->
dim_size
()
==
0
?
0
:
shape
->
dim
(
0
);
Tensor
::
MappingGuard
shape_guard
(
shape
);
Tensor
::
MappingGuard
shape_guard
(
shape
);
...
@@ -40,20 +36,16 @@ class ReshapeOp : public Operation {
...
@@ -40,20 +36,16 @@ class ReshapeOp : public Operation {
int
unknown_idx
=
-
1
;
int
unknown_idx
=
-
1
;
index_t
product
=
1
;
index_t
product
=
1
;
std
::
vector
<
index_t
>
out_shape
;
std
::
vector
<
index_t
>
out_shape
(
num_dims
)
;
index_t
n
=
0
;
index_t
n
=
0
;
for
(
int
i
=
0
;
i
<
num_dims
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_dims
;
++
i
)
{
if
(
shape_data
[
i
]
==
-
1
)
{
if
(
shape_data
[
i
]
==
-
1
)
{
MACE_CHECK
(
unknown_idx
==
-
1
,
"Only one input size may be -1"
);
MACE_CHECK
(
unknown_idx
==
-
1
,
"Only one input size may be -1"
);
unknown_idx
=
i
;
unknown_idx
=
i
;
out_shape
.
push_back
(
1
);
out_shape
[
i
]
=
1
;
}
else
if
(
shape_data
[
i
]
==
0
)
{
MACE_CHECK
(
shape_data
[
i
]
==
0
,
"Shape should be 0"
);
out_shape
.
push_back
(
input_shape
[
i
]);
product
*=
input_shape
[
i
];
}
else
{
}
else
{
MACE_CHECK
(
shape_data
[
i
]
>
0
,
"Shape must be non-negative: "
,
MACE_CHECK
(
shape_data
[
i
]
>
=
0
,
"Shape must be non-negative: "
,
shape_data
[
i
]);
shape_data
[
i
]);
if
(
shape_data
[
i
]
==
0
)
{
if
(
shape_data
[
i
]
==
0
)
{
MACE_CHECK
(
i
<
input
->
dim_size
(),
MACE_CHECK
(
i
<
input
->
dim_size
(),
...
@@ -62,7 +54,7 @@ class ReshapeOp : public Operation {
...
@@ -62,7 +54,7 @@ class ReshapeOp : public Operation {
}
else
{
}
else
{
n
=
shape_data
[
i
];
n
=
shape_data
[
i
];
}
}
out_shape
.
push_back
(
n
)
;
out_shape
[
i
]
=
n
;
product
*=
n
;
product
*=
n
;
}
}
}
}
...
@@ -77,14 +69,13 @@ class ReshapeOp : public Operation {
...
@@ -77,14 +69,13 @@ class ReshapeOp : public Operation {
}
}
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
// NHWC -> NCHW
// NHWC -> NCHW
auto
has_df
=
Operation
::
GetOptionalArg
<
int
>
(
"has_data_format"
,
0
);
if
(
has_df_
&&
D
==
DeviceType
::
CPU
if
(
has_df
&&
D
==
DeviceType
::
CPU
&&
out_shape
.
size
()
==
4
&&
shape
->
is_weight
())
{
&&
out_shape
.
size
()
==
4
&&
shape
->
is_weight
())
{
std
::
vector
<
int
>
dst_dims
=
{
0
,
3
,
1
,
2
};
std
::
vector
<
int
>
dst_dims
=
{
0
,
3
,
1
,
2
};
std
::
vector
<
index_t
>
out_shape_gpu
=
TransposeShape
<
index_t
,
index_t
>
(
std
::
vector
<
index_t
>
trans_shape
=
TransposeShape
<
index_t
,
index_t
>
(
out_shape
,
dst_dims
);
out_shape
,
dst_dims
);
out_shape
=
out_shape_gpu
;
out_shape
=
trans_shape
;
}
}
output
->
ReuseTensorBuffer
(
*
input
);
output
->
ReuseTensorBuffer
(
*
input
);
...
@@ -93,6 +84,9 @@ class ReshapeOp : public Operation {
...
@@ -93,6 +84,9 @@ class ReshapeOp : public Operation {
return
MaceStatus
::
MACE_SUCCESS
;
return
MaceStatus
::
MACE_SUCCESS
;
}
}
private:
bool
has_df_
;
private:
private:
MACE_OP_INPUT_TAGS
(
INPUT
,
SHAPE
);
MACE_OP_INPUT_TAGS
(
INPUT
,
SHAPE
);
MACE_OP_OUTPUT_TAGS
(
OUTPUT
);
MACE_OP_OUTPUT_TAGS
(
OUTPUT
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录