Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
aa52f08f
Mace
项目概览
Xiaomi
/
Mace
通知
107
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
aa52f08f
编写于
10月 21, 2019
作者:
叶
叶剑武
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support fallback from opencl to cpu in ReshapeOp
上级
6a231fdb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
29 addition
and
17 deletion
+29
-17
mace/ops/reshape.cc
mace/ops/reshape.cc
+29
-17
未找到文件。
mace/ops/reshape.cc
浏览文件 @
aa52f08f
...
@@ -18,8 +18,8 @@
...
@@ -18,8 +18,8 @@
#include "mace/utils/math.h"
#include "mace/utils/math.h"
#ifdef MACE_ENABLE_OPENCL
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/reshape.h"
#include "mace/ops/opencl/buffer/reshape.h"
#include "mace/ops/opencl/buffer/reshape.h"
#include "mace/ops/opencl/image/reshape.h"
#endif
#endif
namespace
mace
{
namespace
mace
{
...
@@ -46,8 +46,7 @@ MaceStatus GetOutputShape(const Tensor *input,
...
@@ -46,8 +46,7 @@ MaceStatus GetOutputShape(const Tensor *input,
MACE_CHECK
(
shape_data
[
i
]
>=
0
,
"Shape must be non-negative: "
,
MACE_CHECK
(
shape_data
[
i
]
>=
0
,
"Shape must be non-negative: "
,
shape_data
[
i
]);
shape_data
[
i
]);
if
(
shape_data
[
i
]
==
0
)
{
if
(
shape_data
[
i
]
==
0
)
{
MACE_CHECK
(
i
<
input
->
dim_size
(),
MACE_CHECK
(
i
<
input
->
dim_size
(),
"dims:0 out of input dims' range."
);
"dims:0 out of input dims' range."
);
n
=
input
->
dim
(
i
);
n
=
input
->
dim
(
i
);
}
else
{
}
else
{
n
=
shape_data
[
i
];
n
=
shape_data
[
i
];
...
@@ -59,10 +58,10 @@ MaceStatus GetOutputShape(const Tensor *input,
...
@@ -59,10 +58,10 @@ MaceStatus GetOutputShape(const Tensor *input,
if
(
unknown_idx
!=
-
1
)
{
if
(
unknown_idx
!=
-
1
)
{
MACE_CHECK
(
product
!=
0
)
MACE_CHECK
(
product
!=
0
)
<<
"Cannot infer shape if there is zero shape size."
;
<<
"Cannot infer shape if there is zero shape size."
;
const
index_t
missing
=
input
->
size
()
/
product
;
const
index_t
missing
=
input
->
size
()
/
product
;
MACE_CHECK
(
missing
*
product
==
input
->
size
())
MACE_CHECK
(
missing
*
product
==
input
->
size
())
<<
"Input size not match reshaped tensor size"
;
<<
"Input size not match reshaped tensor size"
;
(
*
out_shape
)[
unknown_idx
]
=
missing
;
(
*
out_shape
)[
unknown_idx
]
=
missing
;
}
}
...
@@ -71,7 +70,7 @@ MaceStatus GetOutputShape(const Tensor *input,
...
@@ -71,7 +70,7 @@ MaceStatus GetOutputShape(const Tensor *input,
}
// namespace
}
// namespace
template
<
DeviceType
D
,
class
T
>
template
<
DeviceType
D
,
class
T
>
class
ReshapeOp
:
public
Operation
{
class
ReshapeOp
:
public
Operation
{
public:
public:
explicit
ReshapeOp
(
OpConstructContext
*
context
)
explicit
ReshapeOp
(
OpConstructContext
*
context
)
...
@@ -90,11 +89,11 @@ class ReshapeOp : public Operation {
...
@@ -90,11 +89,11 @@ class ReshapeOp : public Operation {
GetOutputShape
(
input
,
shape_data
,
num_dims
,
&
out_shape
));
GetOutputShape
(
input
,
shape_data
,
num_dims
,
&
out_shape
));
// NHWC -> NCHW
// NHWC -> NCHW
if
(
has_df_
&&
D
==
DeviceType
::
CPU
if
(
has_df_
&&
D
==
DeviceType
::
CPU
&&
out_shape
.
size
()
==
4
&&
&&
out_shape
.
size
()
==
4
&&
shape
->
is_weight
())
{
shape
->
is_weight
())
{
std
::
vector
<
int
>
dst_dims
=
{
0
,
3
,
1
,
2
};
std
::
vector
<
int
>
dst_dims
=
{
0
,
3
,
1
,
2
};
std
::
vector
<
index_t
>
trans_shape
=
TransposeShape
<
index_t
,
index_t
>
(
std
::
vector
<
index_t
>
trans_shape
=
out_shape
,
dst_dims
);
TransposeShape
<
index_t
,
index_t
>
(
out_shape
,
dst_dims
);
out_shape
=
trans_shape
;
out_shape
=
trans_shape
;
}
}
...
@@ -114,12 +113,11 @@ class ReshapeOp : public Operation {
...
@@ -114,12 +113,11 @@ class ReshapeOp : public Operation {
};
};
#ifdef MACE_ENABLE_OPENCL
#ifdef MACE_ENABLE_OPENCL
template
<
>
template
<
>
class
ReshapeOp
<
GPU
,
float
>
:
public
Operation
{
class
ReshapeOp
<
GPU
,
float
>
:
public
Operation
{
public:
public:
explicit
ReshapeOp
(
OpConstructContext
*
context
)
explicit
ReshapeOp
(
OpConstructContext
*
context
)
:
Operation
(
context
),
:
Operation
(
context
),
dim_
(
Operation
::
GetRepeatedArgs
<
int
>
(
"dim"
))
{
dim_
(
Operation
::
GetRepeatedArgs
<
int
>
(
"dim"
))
{
if
(
context
->
GetOpMemoryType
()
==
MemoryType
::
GPU_IMAGE
)
{
if
(
context
->
GetOpMemoryType
()
==
MemoryType
::
GPU_IMAGE
)
{
kernel_
=
make_unique
<
opencl
::
image
::
ReshapeKernel
>
(
context
);
kernel_
=
make_unique
<
opencl
::
image
::
ReshapeKernel
>
(
context
);
}
else
{
}
else
{
...
@@ -148,11 +146,25 @@ class ReshapeOp<GPU, float> : public Operation {
...
@@ -148,11 +146,25 @@ class ReshapeOp<GPU, float> : public Operation {
#endif
#endif
void
RegisterReshape
(
OpRegistryBase
*
op_registry
)
{
void
RegisterReshape
(
OpRegistryBase
*
op_registry
)
{
MACE_REGISTER_OP
(
op_registry
,
"Reshape"
,
ReshapeOp
,
MACE_REGISTER_OP
(
op_registry
,
"Reshape"
,
ReshapeOp
,
DeviceType
::
CPU
,
float
);
DeviceType
::
CPU
,
float
);
MACE_REGISTER_OP
(
op_registry
,
"Reshape"
,
ReshapeOp
,
DeviceType
::
CPU
,
int32_t
);
MACE_REGISTER_OP
(
op_registry
,
"Reshape"
,
ReshapeOp
,
DeviceType
::
CPU
,
int32_t
);
MACE_REGISTER_GPU_OP
(
op_registry
,
"Reshape"
,
ReshapeOp
);
MACE_REGISTER_GPU_OP
(
op_registry
,
"Reshape"
,
ReshapeOp
);
MACE_REGISTER_OP_CONDITION
(
op_registry
,
OpConditionBuilder
(
"Reshape"
).
SetDevicePlacerFunc
(
[](
OpConditionContext
*
context
)
->
std
::
set
<
DeviceType
>
{
auto
op
=
context
->
operator_def
();
if
(
op
->
output_shape_size
()
!=
op
->
output_size
())
{
return
{
DeviceType
::
CPU
,
DeviceType
::
GPU
};
}
auto
tensor_shape_info
=
context
->
tensor_shape_info
();
const
std
::
string
&
input_0
=
op
->
input
(
0
);
if
(
4
==
op
->
output_shape
(
0
).
dims_size
()
&&
4
==
tensor_shape_info
->
at
(
input_0
).
size
())
{
return
{
DeviceType
::
CPU
,
DeviceType
::
GPU
};
}
return
{
DeviceType
::
CPU
};
}));
}
}
}
// namespace ops
}
// namespace ops
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录