Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
毕竟曾有刹那
Mace
提交
8cee5c3d
Mace
项目概览
毕竟曾有刹那
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8cee5c3d
编写于
6月 13, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make slice more general
上级
5b12c75f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
137 addition
and
64 deletion
+137
-64
mace/kernels/strided_slice.h
mace/kernels/strided_slice.h
+29
-2
mace/ops/strided_slice.h
mace/ops/strided_slice.h
+7
-2
mace/ops/strided_slice_test.cc
mace/ops/strided_slice_test.cc
+99
-59
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+2
-1
未找到文件。
mace/kernels/strided_slice.h
浏览文件 @
8cee5c3d
...
...
@@ -32,12 +32,16 @@ struct StridedSliceFunctor {
int
end_mask
,
int
ellipsis_mask
,
int
new_axis_mask
,
int
shrink_axis_mask
)
int
shrink_axis_mask
,
bool
is_slice
=
false
)
:
begin_mask_
(
begin_mask
),
end_mask_
(
end_mask
),
ellipsis_mask_
(
ellipsis_mask
),
new_axis_mask_
(
new_axis_mask
),
shrink_axis_mask_
(
shrink_axis_mask
)
{}
shrink_axis_mask_
(
shrink_axis_mask
),
is_slice_
(
is_slice
),
tmp_strides_tensor_
(
GetDeviceAllocator
(
D
),
DataTypeToEnum
<
int32_t
>::
v
())
{}
MaceStatus
operator
()(
const
Tensor
*
input
,
const
Tensor
*
begin_indices
,
...
...
@@ -49,6 +53,14 @@ struct StridedSliceFunctor {
MACE_CHECK
(
ellipsis_mask_
==
0
&&
new_axis_mask_
==
0
,
"ellipsis_mask and new_axis_mask are not supported yet."
);
if
(
strides
==
nullptr
)
{
tmp_strides_tensor_
.
Resize
({
begin_indices
->
size
()});
Tensor
::
MappingGuard
strides_guard
(
&
tmp_strides_tensor_
);
int32_t
*
strides_data
=
tmp_strides_tensor_
.
mutable_data
<
int32_t
>
();
std
::
fill
(
strides_data
,
strides_data
+
tmp_strides_tensor_
.
size
(),
1
);
strides
=
&
tmp_strides_tensor_
;
}
Tensor
::
MappingGuard
input_guard
(
input
);
Tensor
::
MappingGuard
begin_indices_guard
(
begin_indices
);
Tensor
::
MappingGuard
end_indices_guard
(
end_indices
);
...
...
@@ -56,6 +68,19 @@ struct StridedSliceFunctor {
const
T
*
input_data
=
input
->
data
<
T
>
();
const
int32_t
*
begin_indices_data
=
begin_indices
->
data
<
int32_t
>
();
const
int32_t
*
end_indices_data
=
end_indices
->
data
<
int32_t
>
();
std
::
vector
<
int32_t
>
slice_end_data
;
if
(
is_slice_
)
{
// if this op is slice, the end_indices_data is size actually
slice_end_data
.
resize
(
end_indices
->
size
());
for
(
int
i
=
0
;
i
<
slice_end_data
.
size
();
++
i
)
{
if
(
end_indices_data
[
i
]
==
-
1
)
{
slice_end_data
[
i
]
=
input
->
dim
(
i
);
}
else
{
slice_end_data
[
i
]
=
begin_indices_data
[
i
]
+
end_indices_data
[
i
];
}
}
end_indices_data
=
slice_end_data
.
data
();
}
const
int32_t
*
strides_data
=
strides
->
data
<
int32_t
>
();
std
::
vector
<
index_t
>
output_shape
;
...
...
@@ -152,6 +177,8 @@ struct StridedSliceFunctor {
int
ellipsis_mask_
;
int
new_axis_mask_
;
int
shrink_axis_mask_
;
bool
is_slice_
;
Tensor
tmp_strides_tensor_
;
};
}
// namespace kernels
...
...
mace/ops/strided_slice.h
浏览文件 @
8cee5c3d
...
...
@@ -30,13 +30,18 @@ class StridedSliceOp : public Operator<D, T> {
OperatorBase
::
GetOptionalArg
<
int
>
(
"end_mask"
,
0
),
OperatorBase
::
GetOptionalArg
<
int
>
(
"ellipsis_mask"
,
0
),
OperatorBase
::
GetOptionalArg
<
int
>
(
"new_axis_mask"
,
0
),
OperatorBase
::
GetOptionalArg
<
int
>
(
"shrink_axis_mask"
,
0
))
{}
OperatorBase
::
GetOptionalArg
<
int
>
(
"shrink_axis_mask"
,
0
),
OperatorBase
::
GetOptionalArg
<
bool
>
(
"slice"
,
false
))
{}
MaceStatus
Run
(
StatsFuture
*
future
)
override
{
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
const
Tensor
*
begin_indices
=
this
->
Input
(
BEGIN
);
const
Tensor
*
end_indices
=
this
->
Input
(
END
);
const
Tensor
*
strides
=
this
->
Input
(
STRIDES
);
const
Tensor
*
strides
=
nullptr
;
if
(
this
->
InputSize
()
>
3
)
{
strides
=
this
->
Input
(
STRIDES
);
}
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
return
functor_
(
input
,
begin_indices
,
end_indices
,
strides
,
output
,
future
);
...
...
mace/ops/strided_slice_test.cc
浏览文件 @
8cee5c3d
...
...
@@ -23,32 +23,27 @@ class StridedSliceOpTest : public OpsTestBase {};
namespace
{
void
TestSlice
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
float
>
&
input
,
const
std
::
vector
<
int32_t
>
&
begin_indices
,
const
std
::
vector
<
int32_t
>
&
end_indices
,
const
std
::
vector
<
int32_t
>
&
strides
,
const
int
begin_mask
,
const
int
end_mask
,
const
int
ellipsis_mask
,
const
int
new_axis_mask
,
const
int
shrink_axis_mask
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
void
TestS
tridedS
lice
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
float
>
&
input
,
const
std
::
vector
<
int32_t
>
&
begin_indices
,
const
std
::
vector
<
int32_t
>
&
end_indices
,
const
std
::
vector
<
int32_t
>
&
strides
,
const
int
begin_mask
,
const
int
end_mask
,
const
int
ellipsis_mask
,
const
int
new_axis_mask
,
const
int
shrink_axis_mask
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
OpsTestNet
net
;
net
.
AddInputFromArray
<
CPU
,
float
>
(
"Input"
,
input_shape
,
input
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
begin_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"EndIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
end_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Strides"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
strides
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
begin_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"EndIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
end_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Strides"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
strides
);
OpDefBuilder
(
"StridedSlice"
,
"StridedSliceOpTest"
)
.
Input
(
"Input"
)
...
...
@@ -70,47 +65,92 @@ void TestSlice(const std::vector<index_t> &input_shape,
*
net
.
GetOutput
(
"Output"
));
}
void
TestSlice
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
float
>
&
input
,
const
std
::
vector
<
int32_t
>
&
begin_indices
,
const
std
::
vector
<
int32_t
>
&
indices_size
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
OpsTestNet
net
;
net
.
AddInputFromArray
<
CPU
,
float
>
(
"Input"
,
input_shape
,
input
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
begin_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"IndicesSize"
,
{
static_cast
<
int32_t
>
(
indices_size
.
size
())},
indices_size
);
OpDefBuilder
(
"StridedSlice"
,
"StridedSliceOpTest"
)
.
Input
(
"Input"
)
.
Input
(
"BeginIndices"
)
.
Input
(
"IndicesSize"
)
.
Output
(
"Output"
)
.
AddIntArg
(
"slice"
,
1
)
.
Finalize
(
net
.
NewOperatorDef
());
net
.
RunOp
();
net
.
AddInputFromArray
<
CPU
,
float
>
(
"ExpectedOutput"
,
output_shape
,
output
);
ExpectTensorNear
<
float
>
(
*
net
.
GetOutput
(
"ExpectedOutput"
),
*
net
.
GetOutput
(
"Output"
));
}
}
// namespace
TEST_F
(
StridedSliceOpTest
,
TestSliceByFirstAxis
)
{
TestSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
1
,
0
,
0
},
{
2
,
3
,
2
},
{
1
,
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
3
,
2
},
{
7
,
8
,
9
,
10
,
11
,
12
});
TestSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
1
,
0
,
0
},
{
2
,
3
,
2
},
{
1
,
1
,
1
},
0
,
0
,
0
,
0
,
1
,
{
3
,
2
},
{
7
,
8
,
9
,
10
,
11
,
12
});
TestSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
1
,
1
,
2
},
{
2
,
3
,
2
},
{
1
,
1
,
1
},
6
,
6
,
0
,
0
,
0
,
{
1
,
3
,
2
},
{
7
,
8
,
9
,
10
,
11
,
12
});
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceByFirstAxis
)
{
TestStridedSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
1
,
0
,
0
},
{
2
,
3
,
2
},
{
1
,
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
3
,
2
},
{
7
,
8
,
9
,
10
,
11
,
12
});
TestStridedSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
1
,
0
,
0
},
{
2
,
3
,
2
},
{
1
,
1
,
1
},
0
,
0
,
0
,
0
,
1
,
{
3
,
2
},
{
7
,
8
,
9
,
10
,
11
,
12
});
TestStridedSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
1
,
1
,
2
},
{
2
,
3
,
2
},
{
1
,
1
,
1
},
6
,
6
,
0
,
0
,
0
,
{
1
,
3
,
2
},
{
7
,
8
,
9
,
10
,
11
,
12
});
}
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceRank1
)
{
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
1
},
{
3
},
{
1
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
2
,
3
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
3
},
{
3
},
{
1
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
2
,
3
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
3
,
2
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
1
},
{
-
4
},
{
-
2
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
4
,
2
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
1
,
0
,
0
,
0
,
0
,
{
3
},
{
4
,
3
,
2
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
0
,
1
,
0
,
0
,
0
,
{
3
},
{
3
,
2
,
1
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
1
,
1
,
0
,
0
,
0
,
{
4
},
{
4
,
3
,
2
,
1
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
2
},
{
4
},
{
2
},
1
,
1
,
0
,
0
,
0
,
{
2
},
{
1
,
3
});
TestStridedSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
2
},
{
3
},
{
1
},
0
,
0
,
0
,
0
,
1
,
{},
{
3
});
}
TEST_F
(
StridedSliceOpTest
,
TestSliceRank1
)
{
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
1
},
{
3
},
{
1
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
2
,
3
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
3
},
{
3
},
{
1
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
2
,
3
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
3
,
2
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
1
},
{
-
4
},
{
-
2
},
0
,
0
,
0
,
0
,
0
,
{
2
},
{
4
,
2
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
1
,
0
,
0
,
0
,
0
,
{
3
},
{
4
,
3
,
2
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
0
,
1
,
0
,
0
,
0
,
{
3
},
{
3
,
2
,
1
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
-
2
},
{
-
4
},
{
-
1
},
1
,
1
,
0
,
0
,
0
,
{
4
},
{
4
,
3
,
2
,
1
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
2
},
{
4
},
{
2
},
1
,
1
,
0
,
0
,
0
,
{
2
},
{
1
,
3
});
TestSlice
({
4
},
{
1
,
2
,
3
,
4
},
{
2
},
{
3
},
{
1
},
0
,
0
,
0
,
0
,
1
,
{},
{
3
});
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceRank2
)
{
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
});
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
1
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
2
},
{
5
,
6
});
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
1
,
2
},
0
,
0
,
0
,
0
,
0
,
{
2
,
2
},
{
1
,
3
,
4
,
6
});
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
2
},
{
0
,
0
},
{
-
1
,
-
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
2
},
{
6
,
5
});
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
2
},
{
0
,
0
},
{
-
1
,
-
1
},
3
,
3
,
0
,
0
,
0
,
{
2
,
3
},
{
6
,
5
,
4
,
3
,
2
,
1
});
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
0
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
1
,
{
3
},
{
4
,
5
,
6
});
TestStridedSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
2
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
3
,
{},
{
6
});
}
TEST_F
(
StridedSliceOpTest
,
TestSliceRank2
)
{
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
1
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
2
},
{
5
,
6
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
1
,
2
},
0
,
0
,
0
,
0
,
0
,
{
2
,
2
},
{
1
,
3
,
4
,
6
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
2
},
{
0
,
0
},
{
-
1
,
-
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
2
},
{
6
,
5
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
2
},
{
0
,
0
},
{
-
1
,
-
1
},
3
,
3
,
0
,
0
,
0
,
{
2
,
3
},
{
6
,
5
,
4
,
3
,
2
,
1
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
0
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
1
,
{
3
},
{
4
,
5
,
6
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
2
},
{
2
,
3
},
{
1
,
1
},
0
,
0
,
0
,
0
,
3
,
{},
{
6
});
TEST_F
(
StridedSliceOpTest
,
TestSlice
)
{
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
1
,
0
},
{
1
,
2
},
{
1
,
2
},
{
4
,
5
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
-
1
},
{
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
});
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
1
},
{
2
,
-
1
},
{
2
,
2
},
{
2
,
3
,
5
,
6
});
}
}
// namespace test
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
8cee5c3d
...
...
@@ -1155,5 +1155,6 @@ class Transformer(base_converter.ConverterInterface):
print
(
"Final ops:"
)
for
op
in
net
.
op
:
print
(
"%s (%s)"
%
(
op
.
name
,
op
.
type
))
print
(
"%s (%s): %s"
%
(
op
.
name
,
op
.
type
,
[
out_shape
.
dims
for
out_shape
in
op
.
output_shape
]))
return
False
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录