Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
冰之2023
Mace
提交
f6006d5e
Mace
项目概览
冰之2023
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f6006d5e
编写于
12月 04, 2017
作者:
L
liuqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change the axis tensor of concat to an attribute.
上级
4743a1e6
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
45 additions
and
32 deletions
+45
-32
mace/kernels/concat.h
mace/kernels/concat.h
+16
-8
mace/kernels/opencl/cl/concat.cl
mace/kernels/opencl/cl/concat.cl
+0
-2
mace/kernels/opencl/concat.cc
mace/kernels/opencl/concat.cc
+4
-5
mace/ops/BUILD
mace/ops/BUILD
+16
-0
mace/ops/concat.h
mace/ops/concat.h
+5
-9
mace/ops/concat_test.cc
mace/ops/concat_test.cc
+4
-8
未找到文件。
mace/kernels/concat.h
浏览文件 @
f6006d5e
...
...
@@ -13,17 +13,24 @@
namespace
mace
{
namespace
kernels
{
struct
ConcatFunctorBase
{
ConcatFunctorBase
(
const
int32_t
axis
)
:
axis_
(
axis
){}
int32_t
axis_
;
};
template
<
DeviceType
D
,
typename
T
>
struct
ConcatFunctor
{
struct
ConcatFunctor
:
ConcatFunctorBase
{
ConcatFunctor
(
const
int32_t
axis
)
:
ConcatFunctorBase
(
axis
){}
void
operator
()(
const
std
::
vector
<
const
Tensor
*>
&
input_list
,
const
int32_t
axis
,
Tensor
*
output
)
{
const
Tensor
*
input0
=
input_list
.
front
();
const
int
inputs_count
=
input_list
.
size
()
-
1
;
const
int
inputs_count
=
input_list
.
size
();
std
::
vector
<
index_t
>
output_shape
(
input0
->
shape
());
index_t
inner_size
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
for
(
int
i
=
0
;
i
<
axis
_
;
++
i
)
{
inner_size
*=
output_shape
[
i
];
}
std
::
vector
<
index_t
>
outer_sizes
(
inputs_count
,
0
);
...
...
@@ -33,14 +40,14 @@ struct ConcatFunctor {
MACE_CHECK
(
input
->
dim_size
()
==
input0
->
dim_size
(),
"Ranks of all input tensors must be same."
);
for
(
int
j
=
0
;
j
<
input
->
dim_size
();
++
j
)
{
if
(
j
==
axis
)
{
if
(
j
==
axis
_
)
{
continue
;
}
MACE_CHECK
(
input
->
dim
(
j
)
==
input0
->
dim
(
j
),
"Dimensions of inputs should equal except axis."
);
}
outer_sizes
[
i
]
=
input
->
size
()
/
inner_size
;
output_shape
[
axis
]
+=
input
->
dim
(
axis
);
output_shape
[
axis
_
]
+=
input
->
dim
(
axis_
);
}
output
->
Resize
(
output_shape
);
...
...
@@ -67,9 +74,10 @@ struct ConcatFunctor {
};
template
<
typename
T
>
struct
ConcatFunctor
<
DeviceType
::
OPENCL
,
T
>
{
struct
ConcatFunctor
<
DeviceType
::
OPENCL
,
T
>
:
ConcatFunctorBase
{
ConcatFunctor
(
const
int32_t
axis
)
:
ConcatFunctorBase
(
axis
){}
void
operator
()(
const
std
::
vector
<
const
Tensor
*>
&
input_list
,
const
int32_t
axis
,
Tensor
*
output
);
};
...
...
mace/kernels/opencl/cl/concat.cl
浏览文件 @
f6006d5e
...
...
@@ -32,8 +32,6 @@ __kernel void concat_channel(__read_only image2d_t input0,
const
int
hb_idx
=
get_global_id
(
2
)
;
const
int
input0_chan_blk
=
(
input0_chan
+
3
)
/
4
;
const
sampler_t
SAMPLER
=
CLK_NORMALIZED_COORDS_FALSE
| CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST
;
DATA_TYPE4
data
=
0
;
#
ifdef
DIVISIBLE_FOUR
if
(
chan_blk_idx
+
1
<=
input0_chan_blk
)
{
...
...
mace/kernels/opencl/concat.cc
浏览文件 @
f6006d5e
...
...
@@ -60,10 +60,9 @@ static void Concat2(const Tensor *input0,
template
<
typename
T
>
void
ConcatFunctor
<
DeviceType
::
OPENCL
,
T
>::
operator
()(
const
std
::
vector
<
const
Tensor
*>
&
input_list
,
const
int32_t
axis
,
Tensor
*
output
)
{
const
int
inputs_count
=
input_list
.
size
()
-
1
;
MACE_CHECK
(
inputs_count
==
2
&&
axis
==
3
)
const
int
inputs_count
=
input_list
.
size
();
MACE_CHECK
(
inputs_count
==
2
&&
axis
_
==
3
)
<<
"Concat opencl kernel only support two elements with axis == 3"
;
const
Tensor
*
input0
=
input_list
[
0
];
...
...
@@ -74,13 +73,13 @@ void ConcatFunctor<DeviceType::OPENCL, T>::operator()(const std::vector<const Te
MACE_CHECK
(
input
->
dim_size
()
==
input0
->
dim_size
(),
"Ranks of all input tensors must be same."
);
for
(
int
j
=
0
;
j
<
input
->
dim_size
();
++
j
)
{
if
(
j
==
axis
)
{
if
(
j
==
axis
_
)
{
continue
;
}
MACE_CHECK
(
input
->
dim
(
j
)
==
input0
->
dim
(
j
),
"Dimensions of inputs should equal except axis."
);
}
output_shape
[
axis
]
+=
input
->
dim
(
axis
);
output_shape
[
axis
_
]
+=
input
->
dim
(
axis_
);
}
std
::
vector
<
size_t
>
image_shape
;
CalImage2DShape
(
output_shape
,
BufferType
::
IN_OUT
,
image_shape
);
...
...
mace/ops/BUILD
浏览文件 @
f6006d5e
...
...
@@ -61,6 +61,22 @@ cc_test(
],
)
cc_test
(
name
=
"concat_test"
,
testonly
=
1
,
srcs
=
glob
(
[
"concat_test.cc"
],
),
copts
=
[
"-std=c++11"
],
linkopts
=
[
"-fopenmp"
],
linkstatic
=
1
,
deps
=
[
":ops"
,
":test"
,
"@gtest//:gtest_main"
,
],
)
cc_test
(
name
=
"ops_benchmark"
,
testonly
=
1
,
...
...
mace/ops/concat.h
浏览文件 @
f6006d5e
...
...
@@ -14,17 +14,13 @@ template <DeviceType D, typename T>
class
ConcatOp
:
public
Operator
<
D
,
T
>
{
public:
ConcatOp
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
Operator
<
D
,
T
>
(
op_def
,
ws
)
{}
:
Operator
<
D
,
T
>
(
op_def
,
ws
),
functor_
(
OperatorBase
::
GetSingleArgument
<
int
>
(
"axis"
,
3
)){}
bool
Run
()
override
{
const
int32_t
inputs_count
=
this
->
InputSize
()
-
1
;
MACE_CHECK
(
this
->
InputSize
()
>=
2
)
<<
"There must be at least two inputs to concat"
;
const
std
::
vector
<
const
Tensor
*>
input_list
=
this
->
Inputs
();
const
Tensor
*
axis_tensor
=
this
->
Input
(
inputs_count
);
MACE_CHECK
(
axis_tensor
->
dim_size
()
==
0
,
"axis should be a scalar integer, but got shape: "
,
axis_tensor
->
dim_size
());
Tensor
::
MappingGuard
axis_mapper
(
axis_tensor
);
const
int32_t
concat_axis
=
*
(
axis_tensor
->
data
<
int32_t
>
());
const
int32_t
concat_axis
=
OperatorBase
::
GetSingleArgument
<
int
>
(
"axis"
,
3
);
const
int32_t
input_dims
=
input_list
[
0
]
->
dim_size
();
const
int32_t
axis
=
concat_axis
<
0
?
concat_axis
+
input_dims
:
concat_axis
;
...
...
@@ -34,7 +30,7 @@ class ConcatOp : public Operator<D, T> {
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
functor_
(
input_list
,
axis
,
output
);
functor_
(
input_list
,
output
);
return
true
;
}
...
...
mace/ops/concat_test.cc
浏览文件 @
f6006d5e
...
...
@@ -16,7 +16,7 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) {
OpDefBuilder
(
"Concat"
,
"ConcatTest"
)
.
Input
(
"Input0"
)
.
Input
(
"Input1"
)
.
Input
(
"Axis"
)
.
AddIntArg
(
"axis"
,
0
)
.
Output
(
"Output"
)
.
Finalize
(
net
.
NewOperatorDef
());
...
...
@@ -28,7 +28,6 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) {
// Add inputs
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
float
>
(
"Input0"
,
input_shape
,
input0
);
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
float
>
(
"Input1"
,
input_shape
,
input1
);
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
int
>
(
"Axis"
,
{},
{
0
});
// Run
net
.
RunOp
();
...
...
@@ -54,7 +53,7 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) {
OpDefBuilder
(
"Concat"
,
"ConcatTest"
)
.
Input
(
"Input0"
)
.
Input
(
"Input1"
)
.
Input
(
"Axis"
)
.
AddIntArg
(
"axis"
,
1
)
.
Output
(
"Output"
)
.
Finalize
(
net
.
NewOperatorDef
());
...
...
@@ -66,7 +65,6 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) {
// Add inputs
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
float
>
(
"Input0"
,
input_shape
,
input0
);
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
float
>
(
"Input1"
,
input_shape
,
input1
);
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
int
>
(
"Axis"
,
{},
{
1
});
// Run
net
.
RunOp
();
...
...
@@ -99,7 +97,7 @@ TEST_F(ConcatOpTest, CPURandom) {
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
builder
=
builder
.
Input
((
"Input"
+
ToString
(
i
)).
c_str
());
}
builder
.
Input
(
"Axis"
).
Output
(
"Output"
).
Finalize
(
net
.
NewOperatorDef
());
builder
.
AddIntArg
(
"axis"
,
axis
).
Output
(
"Output"
).
Finalize
(
net
.
NewOperatorDef
());
std
::
vector
<
index_t
>
shape_data
;
GenerateRandomIntTypeData
<
index_t
>
({
dim
},
shape_data
,
1
,
dim
);
...
...
@@ -115,7 +113,6 @@ TEST_F(ConcatOpTest, CPURandom) {
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
float
>
((
"Input"
+
ToString
(
i
)).
c_str
(),
input_shapes
[
i
],
inputs
[
i
]);
}
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
int
>
(
"Axis"
,
{},
{
axis
});
// Run
net
.
RunOp
();
...
...
@@ -156,14 +153,13 @@ void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
shapes
[
i
]);
BufferToImage
<
DeviceType
::
OPENCL
,
T
>
(
net
,
input_name
,
image_name
,
kernels
::
BufferType
::
IN_OUT
);
}
net
.
AddInputFromArray
<
DeviceType
::
OPENCL
,
int
>
(
"Axis"
,
{},
{
axis
});
auto
builder
=
OpDefBuilder
(
"Concat"
,
"ConcatTest"
);
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
const
std
::
string
image_name
=
(
"InputImage"
+
ToString
(
i
)).
c_str
();
builder
=
builder
.
Input
(
image_name
);
}
builder
.
Input
(
"Axis"
)
builder
.
AddIntArg
(
"axis"
,
axis
)
.
Output
(
"OutputImage"
)
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
.
Finalize
(
net
.
NewOperatorDef
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录