Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Mr.Vain
Mace
提交
a14a6cb4
Mace
项目概览
Mr.Vain
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a14a6cb4
编写于
3月 22, 2019
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor strided_slice and fix bound check, update cumsum impl
上级
a85f052d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
535 addition
and
218 deletion
+535
-218
mace/ops/cumsum.cc
mace/ops/cumsum.cc
+66
-80
mace/ops/cumsum_benchmark.cc
mace/ops/cumsum_benchmark.cc
+90
-0
mace/ops/cumsum_test.cc
mace/ops/cumsum_test.cc
+50
-17
mace/ops/strided_slice.cc
mace/ops/strided_slice.cc
+175
-121
mace/ops/strided_slice_test.cc
mace/ops/strided_slice_test.cc
+154
-0
未找到文件。
mace/ops/cumsum.cc
浏览文件 @
a14a6cb4
...
...
@@ -19,26 +19,6 @@
namespace
mace
{
namespace
ops
{
namespace
{
void
PlusOne
(
int
*
val
)
{
++
(
*
val
);
}
void
SubOne
(
int
*
val
)
{
--
(
*
val
);
}
bool
LessThan
(
const
int
&
val
,
const
int
&
boundary
)
{
return
val
<
boundary
;
}
bool
NotLessThanZero
(
const
int
&
val
,
const
int
&
boundary
)
{
MACE_UNUSED
(
boundary
);
return
val
>=
0
;
}
}
// namespace
template
<
DeviceType
D
,
typename
T
>
class
CumsumOp
;
...
...
@@ -47,9 +27,10 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
public:
explicit
CumsumOp
(
OpConstructContext
*
context
)
:
Operation
(
context
),
axis_
(
Operation
::
GetOptionalArg
<
int
>
(
"axis"
,
3
)),
axis_
(
Operation
::
GetOptionalArg
<
int
>
(
"axis"
,
0
)),
exclusive_
(
Operation
::
GetOptionalArg
<
bool
>
(
"exclusive"
,
false
)),
reverse_
(
Operation
::
GetOptionalArg
<
bool
>
(
"reverse"
,
false
))
{}
reverse_
(
Operation
::
GetOptionalArg
<
bool
>
(
"reverse"
,
false
)),
checked_
(
false
)
{}
void
Validate
()
{
const
int32_t
input_dims
=
this
->
Input
(
0
)
->
dim_size
();
...
...
@@ -64,9 +45,9 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
MACE_UNUSED
(
context
);
if
(
!
checked_
)
{
Validate
();
auto
df
=
static_cast
<
DataFormat
>
(
Operation
::
GetOptionalArg
<
int
>
(
"
data_format"
,
DataFormat
::
DF_NONE
)
);
if
(
df
==
DataFormat
::
NHWC
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
bool
has_data_format
=
Operation
::
GetOptionalArg
<
int
>
(
"
has_data_format"
,
0
);
if
(
has_data_format
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
if
(
axis_
==
3
)
axis_
=
1
;
else
if
(
axis_
==
2
)
axis_
=
3
;
else
if
(
axis_
==
1
)
axis_
=
2
;
...
...
@@ -75,6 +56,7 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
}
const
Tensor
*
input
=
this
->
Input
(
0
);
const
std
::
vector
<
index_t
>
input_shape
=
input
->
shape
();
Tensor
*
output
=
this
->
Output
(
0
);
MACE_RETURN_IF_ERROR
(
output
->
ResizeLike
(
input
));
...
...
@@ -85,66 +67,70 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
const
float
*
input_ptr
=
input
->
data
<
float
>
();
float
*
output_ptr
=
output
->
mutable_data
<
float
>
();
std
::
function
<
void
(
int
*
)
>
next
=
reverse_
?
SubOne
:
PlusOne
;
std
::
function
<
void
(
int
*
)
>
previous
=
reverse_
?
PlusOne
:
SubOne
;
std
::
function
<
bool
(
const
int
&
,
const
int
&
)
>
boundary
=
reverse_
?
NotLessThanZero
:
LessThan
;
if
(
input
->
dim_size
()
==
4
)
{
const
int
batch
=
input
->
dim
(
0
);
const
int
channel
=
input
->
dim
(
1
);
const
int
height
=
input
->
dim
(
2
);
const
int
width
=
input
->
dim
(
3
);
const
int
axis_dim_size
=
input
->
dim
(
axis_
);
for
(
int
n
=
reverse_
?
batch
-
1
:
0
;
boundary
(
n
,
batch
);
next
(
&
n
))
{
for
(
int
c
=
reverse_
?
channel
-
1
:
0
;
boundary
(
c
,
channel
);
next
(
&
c
))
{
for
(
int
h
=
reverse_
?
height
-
1
:
0
;
boundary
(
h
,
height
);
next
(
&
h
))
{
for
(
int
w
=
reverse_
?
width
-
1
:
0
;
boundary
(
w
,
width
);
next
(
&
w
))
{
int
dims
[
4
]
=
{
n
,
c
,
h
,
w
};
if
(
!
reverse_
&&
dims
[
axis_
]
==
0
)
{
if
(
exclusive_
)
{
output_ptr
[((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
]
=
0
;
}
else
{
continue
;
}
}
else
if
(
reverse_
&&
dims
[
axis_
]
==
axis_dim_size
-
1
)
{
if
(
exclusive_
)
{
output_ptr
[((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
]
=
0
;
}
else
{
continue
;
}
}
else
{
previous
(
&
dims
[
axis_
]);
if
(
exclusive_
)
{
output_ptr
[((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
]
=
input_ptr
[((
dims
[
0
]
*
channel
+
dims
[
1
])
*
height
+
dims
[
2
])
*
width
+
dims
[
3
]]
+
output_ptr
[((
dims
[
0
]
*
channel
+
dims
[
1
])
*
height
+
dims
[
2
])
*
width
+
dims
[
3
]];
}
else
{
output_ptr
[((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
]
=
input_ptr
[((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
]
+
output_ptr
[((
dims
[
0
]
*
channel
+
dims
[
1
])
*
height
+
dims
[
2
])
*
width
+
dims
[
3
]];
}
}
const
index_t
outer_size
=
std
::
accumulate
(
input_shape
.
begin
(),
input_shape
.
begin
()
+
axis_
,
1
,
std
::
multiplies
<
index_t
>
());
const
index_t
inner_size
=
std
::
accumulate
(
input_shape
.
begin
()
+
axis_
+
1
,
input_shape
.
end
(),
1
,
std
::
multiplies
<
index_t
>
());
const
index_t
cum_size
=
input_shape
[
axis_
];
if
(
!
reverse_
)
{
#pragma omp parallel for
for
(
index_t
outer_idx
=
0
;
outer_idx
<
outer_size
;
++
outer_idx
)
{
index_t
start_idx
=
outer_idx
*
cum_size
*
inner_size
;
for
(
index_t
cum_idx
=
0
;
cum_idx
<
cum_size
;
++
cum_idx
)
{
if
(
cum_idx
==
0
)
{
if
(
exclusive_
)
{
std
::
memset
(
output_ptr
+
start_idx
,
0
,
sizeof
(
T
)
*
inner_size
);
}
else
{
std
::
memcpy
(
output_ptr
+
start_idx
,
input_ptr
+
start_idx
,
sizeof
(
T
)
*
inner_size
);
}
}
else
{
index_t
cur_idx
=
start_idx
+
cum_idx
*
inner_size
;
index_t
pre_idx
=
start_idx
+
(
cum_idx
-
1
)
*
inner_size
;
index_t
input_idx
=
exclusive_
?
pre_idx
:
cur_idx
;
for
(
index_t
inner_idx
=
0
;
inner_idx
<
inner_size
;
++
inner_idx
)
{
output_ptr
[
cur_idx
+
inner_idx
]
=
output_ptr
[
pre_idx
+
inner_idx
]
+
input_ptr
[
input_idx
+
inner_idx
];
}
}
}
}
}
else
{
MACE_NOT_IMPLEMENTED
;
#pragma omp parallel for
for
(
index_t
outer_idx
=
outer_size
-
1
;
outer_idx
>=
0
;
--
outer_idx
)
{
index_t
start_idx
=
outer_idx
*
cum_size
*
inner_size
;
for
(
index_t
cum_idx
=
cum_size
-
1
;
cum_idx
>=
0
;
--
cum_idx
)
{
index_t
cur_idx
=
start_idx
+
cum_idx
*
inner_size
;
if
(
cum_idx
==
cum_size
-
1
)
{
if
(
exclusive_
)
{
std
::
memset
(
output_ptr
+
cur_idx
,
0
,
sizeof
(
T
)
*
inner_size
);
}
else
{
std
::
memcpy
(
output_ptr
+
cur_idx
,
input_ptr
+
cur_idx
,
sizeof
(
T
)
*
inner_size
);
}
}
else
{
index_t
pre_idx
=
start_idx
+
(
cum_idx
+
1
)
*
inner_size
;
index_t
input_idx
=
exclusive_
?
pre_idx
:
cur_idx
;
for
(
index_t
inner_idx
=
0
;
inner_idx
<
inner_size
;
++
inner_idx
)
{
output_ptr
[
cur_idx
+
inner_idx
]
=
output_ptr
[
pre_idx
+
inner_idx
]
+
input_ptr
[
input_idx
+
inner_idx
];
}
}
}
}
}
return
MaceStatus
::
MACE_SUCCESS
;
...
...
mace/ops/cumsum_benchmark.cc
0 → 100644
浏览文件 @
a14a6cb4
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace
mace
{
namespace
ops
{
namespace
test
{
class
CumsumOpTest
:
public
OpsTestBase
{};
namespace
{
template
<
DeviceType
D
,
typename
T
>
void
Cumsum
(
int
iters
,
int
batch
,
int
channels
,
int
height
,
int
width
)
{
mace
::
testing
::
StopTiming
();
// Construct graph
OpsTestNet
net
;
// Add input data
if
(
D
==
DeviceType
::
CPU
)
{
net
.
AddRandomInput
<
D
,
T
>
(
"Input"
,
{
batch
,
channels
,
height
,
width
});
}
else
{
MACE_NOT_IMPLEMENTED
;
}
OpDefBuilder
(
"Cumsum"
,
"CumsumTest"
)
.
Input
(
"Input"
)
.
Output
(
"Output"
)
.
AddIntArg
(
"axis"
,
0
)
.
AddIntArg
(
"exclusive"
,
0
)
.
AddIntArg
(
"reverse"
,
0
)
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
.
Finalize
(
net
.
NewOperatorDef
());
// Warm-up
for
(
int
i
=
0
;
i
<
5
;
++
i
)
{
net
.
RunOp
(
D
);
}
net
.
Sync
();
mace
::
testing
::
StartTiming
();
while
(
iters
--
)
{
net
.
RunOp
(
D
);
}
net
.
Sync
();
}
}
// namespace
#define MACE_BM_CUMSUM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Cumsum<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK(MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_CUMSUM(N, C, H, W) \
MACE_BM_CUMSUM_MACRO(N, C, H, W, float, CPU);
MACE_BM_CUMSUM
(
1
,
1
,
512
,
512
);
MACE_BM_CUMSUM
(
1
,
3
,
128
,
128
);
MACE_BM_CUMSUM
(
1
,
3
,
512
,
512
);
MACE_BM_CUMSUM
(
1
,
32
,
112
,
112
);
MACE_BM_CUMSUM
(
1
,
64
,
256
,
256
);
MACE_BM_CUMSUM
(
1
,
64
,
512
,
512
);
MACE_BM_CUMSUM
(
1
,
128
,
56
,
56
);
MACE_BM_CUMSUM
(
1
,
128
,
256
,
256
);
MACE_BM_CUMSUM
(
1
,
256
,
14
,
14
);
MACE_BM_CUMSUM
(
1
,
512
,
14
,
14
);
MACE_BM_CUMSUM
(
1
,
1024
,
7
,
7
);
MACE_BM_CUMSUM
(
32
,
1
,
256
,
256
);
MACE_BM_CUMSUM
(
32
,
3
,
256
,
256
);
}
// namespace test
}
// namespace ops
}
// namespace mace
mace/ops/cumsum_test.cc
浏览文件 @
a14a6cb4
...
...
@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <vector>
#include "mace/ops/ops_test_util.h"
namespace
mace
{
...
...
@@ -24,33 +21,69 @@ namespace test {
class
CumsumOpTest
:
public
OpsTestBase
{};
namespace
{
void
SimpleTest
()
{
template
<
typename
T
>
void
SimpleTestWithDataFormat
(
const
std
::
vector
<
index_t
>
&
shape
,
const
std
::
vector
<
float
>
&
input
,
const
int
axis
,
const
int
exclusive
,
const
int
reverse
,
const
std
::
vector
<
float
>
&
output
)
{
// Construct graph
OpsTestNet
net
;
net
.
AddInputFromArray
<
DeviceType
::
CPU
,
float
>
(
"Input"
,
{
2
,
2
,
2
,
2
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
});
net
.
AddInputFromArray
<
CPU
,
T
>
(
"Input"
,
shape
,
input
);
net
.
TransformDataFormat
<
DeviceType
::
CPU
,
float
>
(
"Input"
,
NHWC
,
"InputNCHW"
,
NCHW
);
OpDefBuilder
(
"Cumsum"
,
"CumsumTest"
)
.
Input
(
"Input"
)
.
Output
(
"Output"
)
.
AddIntArg
(
"axis"
,
1
)
.
AddIntArg
(
"exclusive"
,
1
)
.
AddIntArg
(
"reverse"
,
1
)
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
float
>::
value
))
.
Input
(
"InputNCHW"
)
.
Output
(
"OutputNCHW"
)
.
AddIntArg
(
"axis"
,
axis
)
.
AddIntArg
(
"exclusive"
,
exclusive
)
.
AddIntArg
(
"reverse"
,
reverse
)
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
.
AddIntArg
(
"has_data_format"
,
1
)
.
Finalize
(
net
.
NewOperatorDef
());
// Run
net
.
RunOp
(
DeviceType
::
CPU
);
auto
expected
=
net
.
CreateTensor
<
float
>
({
2
,
2
,
2
,
2
},
{
4.
,
5.
,
6.
,
7.
,
0.
,
0.
,
0.
,
0.
,
12.
,
13.
,
14.
,
15.
,
0.
,
0.
,
0.
,
0.
});
ExpectTensorNear
<
float
,
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-5
);
net
.
TransformDataFormat
<
DeviceType
::
CPU
,
float
>
(
"OutputNCHW"
,
NCHW
,
"Output"
,
NHWC
);
net
.
AddInputFromArray
<
CPU
,
T
>
(
"ExpectedOutput"
,
shape
,
output
);
ExpectTensorNear
<
T
>
(
*
net
.
GetOutput
(
"ExpectedOutput"
),
*
net
.
GetOutput
(
"Output"
));
}
}
// namespace
TEST_F
(
CumsumOpTest
,
CPU
)
{
SimpleTest
();
TEST_F
(
CumsumOpTest
,
HasDataFormatCPU
)
{
SimpleTestWithDataFormat
<
float
>
(
{
2
,
2
,
2
,
2
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
},
0
,
0
,
0
,
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
10.
,
12.
,
14.
,
16.
,
18.
,
20.
,
22.
});
SimpleTestWithDataFormat
<
float
>
(
{
2
,
2
,
2
,
2
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
},
1
,
0
,
0
,
{
0.
,
1.
,
2.
,
3.
,
4.
,
6.
,
8.
,
10.
,
8.
,
9.
,
10.
,
11.
,
20.
,
22.
,
24.
,
26.
});
SimpleTestWithDataFormat
<
float
>
(
{
2
,
2
,
2
,
2
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
},
0
,
1
,
0
,
{
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
});
SimpleTestWithDataFormat
<
float
>
(
{
2
,
2
,
2
,
2
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
},
0
,
0
,
1
,
{
8.
,
10.
,
12.
,
14.
,
16.
,
18.
,
20.
,
22.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
});
SimpleTestWithDataFormat
<
float
>
(
{
2
,
2
,
2
,
2
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
,
8.
,
9.
,
10.
,
11.
,
12.
,
13.
,
14.
,
15.
},
1
,
1
,
1
,
{
4.
,
5.
,
6.
,
7.
,
0.
,
0.
,
0.
,
0.
,
12.
,
13.
,
14.
,
15.
,
0.
,
0.
,
0.
,
0.
});
}
}
// namespace test
...
...
mace/ops/strided_slice.cc
浏览文件 @
a14a6cb4
...
...
@@ -17,6 +17,7 @@
#include <vector>
#include "mace/core/operator.h"
#include "mace/utils/math.h"
namespace
mace
{
namespace
ops
{
...
...
@@ -33,6 +34,7 @@ class StridedSliceOp : public Operation {
shrink_axis_mask_
(
Operation
::
GetOptionalArg
<
int
>
(
"shrink_axis_mask"
,
0
)),
is_slice_
(
Operation
::
GetOptionalArg
<
bool
>
(
"slice"
,
false
)),
has_data_format_
(
Operation
::
GetOptionalArg
<
int
>
(
"has_data_format"
,
0
)),
checked_
(
false
)
{
MACE_CHECK
(
ellipsis_mask_
==
0
&&
new_axis_mask_
==
0
,
"ellipsis_mask and new_axis_mask are not supported yet."
);
...
...
@@ -62,14 +64,21 @@ class StridedSliceOp : public Operation {
(
*
dims
)[
3
]
=
w
;
}
void
TransposeDimsFromNCHWToNHWC
(
std
::
vector
<
int32_t
>*
dims
)
{
int32_t
c
=
(
*
dims
)[
1
];
int32_t
h
=
(
*
dims
)[
2
];
int32_t
w
=
(
*
dims
)[
3
];
(
*
dims
)[
1
]
=
h
;
(
*
dims
)[
2
]
=
w
;
(
*
dims
)[
3
]
=
c
;
}
MaceStatus
Run
(
OpContext
*
context
)
override
{
MACE_UNUSED
(
context
);
auto
df
=
static_cast
<
DataFormat
>
(
Operation
::
GetOptionalArg
<
int
>
(
"data_format"
,
DataFormat
::
DF_NONE
));
if
(
!
checked_
)
{
if
(
df
==
DataFormat
::
NHWC
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
if
(
has_data_format_
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
TransposeMaskValueFromNHWCToNCHW
(
&
begin_mask_
);
TransposeMaskValueFromNHWCToNCHW
(
&
end_mask_
);
TransposeMaskValueFromNHWCToNCHW
(
&
ellipsis_mask_
);
...
...
@@ -78,14 +87,15 @@ class StridedSliceOp : public Operation {
}
checked_
=
true
;
}
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
const
Tensor
*
begin_indices
=
this
->
Input
(
BEGIN
);
const
Tensor
*
end_indices
=
this
->
Input
(
END
);
const
Tensor
*
strides
=
nullptr
;
if
(
this
->
InputSize
()
>
3
)
{
strides
=
this
->
Input
(
STRIDES
);
}
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
if
(
strides
==
nullptr
)
{
tmp_strides_tensor_
.
Resize
({
begin_indices
->
size
()});
Tensor
::
MappingGuard
strides_guard
(
&
tmp_strides_tensor_
);
...
...
@@ -94,6 +104,11 @@ class StridedSliceOp : public Operation {
strides
=
&
tmp_strides_tensor_
;
}
MACE_CHECK
(
begin_indices
->
dim_size
()
==
1
&&
end_indices
->
dim_size
()
==
1
&&
strides
->
dim_size
()
==
1
,
"Expected begin, end, and strides to be 1D tensor"
);
Tensor
::
MappingGuard
input_guard
(
input
);
Tensor
::
MappingGuard
begin_indices_guard
(
begin_indices
);
Tensor
::
MappingGuard
end_indices_guard
(
end_indices
);
...
...
@@ -102,107 +117,145 @@ class StridedSliceOp : public Operation {
const
int32_t
*
begin_indices_data
=
begin_indices
->
data
<
int32_t
>
();
const
int32_t
*
end_indices_data
=
end_indices
->
data
<
int32_t
>
();
const
int32_t
*
strides_data
=
strides
->
data
<
int32_t
>
();
std
::
vector
<
int32_t
>
pad_begin_indices
(
input
->
dim_size
(),
0
);
std
::
vector
<
int32_t
>
pad_end_indices
(
input
->
dim_size
(),
0
);
std
::
vector
<
int32_t
>
pad_strides_indices
(
input
->
dim_size
(),
1
);
if
(
begin_indices
->
size
()
<
input
->
dim_size
())
{
for
(
index_t
i
=
0
;
i
<
begin_indices
->
size
();
++
i
)
{
pad_begin_indices
[
i
]
=
begin_indices_data
[
i
];
pad_end_indices
[
i
]
=
end_indices_data
[
i
];
pad_strides_indices
[
i
]
=
strides_data
[
i
];
}
for
(
index_t
i
=
begin_indices
->
size
();
i
<
input
->
dim_size
();
++
i
)
{
pad_end_indices
[
i
]
=
input
->
dim
(
i
);
}
begin_indices_data
=
pad_begin_indices
.
data
();
end_indices_data
=
pad_end_indices
.
data
();
strides_data
=
pad_strides_indices
.
data
();
}
std
::
vector
<
int32_t
>
begin_indices_vec
(
begin_indices_data
,
begin_indices_data
+
begin_indices
->
size
());
std
::
vector
<
int32_t
>
end_indices_vec
(
end_indices_data
,
end_indices_data
+
end_indices
->
size
());
std
::
vector
<
int32_t
>
strides_indices_vec
(
strides_data
,
strides_data
+
strides
->
size
());
std
::
vector
<
int32_t
>
transpose_begin_indices
(
input
->
dim_size
(),
0
);
std
::
vector
<
int32_t
>
transpose_end_indices
(
input
->
dim_size
(),
0
);
std
::
vector
<
int32_t
>
transpose_strides_indices
(
input
->
dim_size
(),
1
);
if
(
df
==
DataFormat
::
NHWC
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
for
(
index_t
i
=
0
;
i
<
begin_indices
->
size
();
++
i
)
{
transpose_begin_indices
[
i
]
=
begin_indices_data
[
i
];
transpose_end_indices
[
i
]
=
end_indices_data
[
i
];
transpose_strides_indices
[
i
]
=
strides_data
[
i
];
}
TransposeDimsFromNHWCToNCHW
(
&
transpose_begin_indices
);
TransposeDimsFromNHWCToNCHW
(
&
transpose_end_indices
);
TransposeDimsFromNHWCToNCHW
(
&
transpose_strides_indices
);
MACE_CHECK
(
input
->
size
()
>
0
&&
input
->
dim_size
()
>
0
&&
input
->
dim_size
()
<=
4
,
"The input size should larger than 0."
" And input dims should be an integer in (0, 4]."
);
begin_indices_data
=
transpose_begin_indices
.
data
();
end_indices_data
=
transpose_end_indices
.
data
();
strides_data
=
transpose_strides_indices
.
data
();
}
std
::
vector
<
index_t
>
output_shape
=
{};
std
::
vector
<
int32_t
>
slice_end_data
;
const
size_t
input_dims
=
input
->
dim_size
()
;
if
(
is_slice_
)
{
// if this op is slice, the end_indices_data is size actually
slice_end_data
.
resize
(
end_indices
->
size
());
for
(
size_t
i
=
0
;
i
<
slice_end_data
.
size
();
++
i
)
{
if
(
end_indices_data
[
i
]
==
-
1
)
{
slice_end_data
[
i
]
=
input
->
dim
(
i
);
}
else
{
slice_end_data
[
i
]
=
begin_indices_data
[
i
]
+
end_indices_data
[
i
];
MACE_CHECK
(
begin_indices_vec
.
size
()
==
input_dims
&&
end_indices_vec
.
size
()
==
input_dims
,
"In slice, begin and size elements num should be equal"
);
// transpose
if
(
has_data_format_
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
TransposeDimsFromNHWCToNCHW
(
&
begin_indices_vec
);
TransposeDimsFromNHWCToNCHW
(
&
end_indices_vec
);
TransposeDimsFromNHWCToNCHW
(
&
strides_indices_vec
);
}
for
(
size_t
i
=
0
;
i
<
input_dims
;
++
i
)
{
if
(
end_indices_vec
[
i
]
==
-
1
)
{
end_indices_vec
[
i
]
=
input
->
dim
(
i
)
-
begin_indices_vec
[
i
];
}
}
end_indices_data
=
slice_end_data
.
data
();
}
std
::
vector
<
index_t
>
output_shape
;
std
::
vector
<
index_t
>
real_begin_indices
(
input
->
dim_size
(),
0
);
std
::
vector
<
index_t
>
real_end_indices
(
input
->
dim_size
(),
0
);
for
(
index_t
d
=
0
;
d
<
input
->
dim_size
();
++
d
)
{
index_t
dim_len
=
input
->
dim
(
d
);
if
(
begin_mask_
&
(
1
<<
d
))
{
real_begin_indices
[
d
]
=
strides_data
[
d
]
>
0
?
0
:
dim_len
-
1
;
}
else
{
real_begin_indices
[
d
]
=
(
begin_indices_data
[
d
]
+
dim_len
)
%
dim_len
;
for
(
size_t
i
=
0
;
i
<
input_dims
;
++
i
)
{
int32_t
b
=
begin_indices_vec
[
i
];
int32_t
s
=
end_indices_vec
[
i
];
int32_t
input_i
=
input
->
dim
(
i
);
MACE_CHECK
(
0
<=
b
&&
b
<=
input_i
,
"In Slice, expected begin["
,
i
,
"] in [0, "
,
input_i
,
"], but got "
,
b
);
MACE_CHECK
(
0
<=
s
&&
b
+
s
<=
input_i
,
"In Slice, expected size["
,
i
,
"] in [0, "
,
input_i
-
b
,
"], but got"
,
s
);
end_indices_vec
[
i
]
=
b
+
s
;
output_shape
.
push_back
(
s
);
}
if
(
end_mask_
&
(
1
<<
d
))
{
real_end_indices
[
d
]
=
strides_data
[
d
]
>
0
?
dim_len
:
-
1
;
}
else
{
real_end_indices
[
d
]
=
end_indices_data
[
d
]
<
-
dim_len
?
-
1
:
(
end_indices_data
[
d
]
<
0
?
(
end_indices_data
[
d
]
+
dim_len
)
:
std
::
min
(
static_cast
<
index_t
>
(
end_indices_data
[
d
]),
dim_len
));
}
else
{
MACE_CHECK
(
begin_indices_vec
.
size
()
==
end_indices_vec
.
size
()
&&
end_indices_vec
.
size
()
==
strides_indices_vec
.
size
(),
"In strided_slice, expected begin, end, and strides to be"
,
" equal size tensors"
);
for
(
index_t
i
=
0
;
i
<
strides
->
size
();
++
i
)
{
MACE_CHECK
(
strides_indices_vec
[
i
]
!=
0
,
"strides data cannot be 0!"
);
}
int32_t
out_dim_len
=
std
::
max
(
0.
f
,
std
::
ceil
((
real_end_indices
[
d
]
-
real_begin_indices
[
d
])
/
static_cast
<
float
>
(
strides_data
[
d
])));
if
(
!
(
shrink_axis_mask_
&
(
1
<<
d
)))
{
output_shape
.
push_back
(
out_dim_len
);
}
else
{
MACE_CHECK
(
out_dim_len
==
1
,
"cannot shrink axis that has len > 1, dim("
,
d
,
"): ["
,
real_begin_indices
[
d
],
", "
,
real_end_indices
[
d
],
"]"
);
// pad
begin_indices_vec
.
resize
(
input_dims
,
0
);
strides_indices_vec
.
resize
(
input_dims
,
1
);
std
::
vector
<
int32_t
>
tmp_input_dims
(
input
->
shape
().
begin
(),
input
->
shape
().
end
());
if
(
has_data_format_
&&
input_dims
==
4
)
{
TransposeDimsFromNCHWToNHWC
(
&
tmp_input_dims
);
}
for
(
size_t
i
=
end_indices_vec
.
size
();
i
<
input_dims
;
++
i
)
{
end_indices_vec
.
push_back
(
tmp_input_dims
[
i
]);
}
// transpose
if
(
has_data_format_
&&
this
->
Input
(
0
)
->
dim_size
()
==
4
)
{
TransposeDimsFromNHWCToNCHW
(
&
begin_indices_vec
);
TransposeDimsFromNHWCToNCHW
(
&
end_indices_vec
);
TransposeDimsFromNHWCToNCHW
(
&
strides_indices_vec
);
}
// mask and shrink
for
(
index_t
d
=
0
;
d
<
input
->
dim_size
();
++
d
)
{
index_t
dim_len
=
input
->
dim
(
d
);
const
std
::
vector
<
index_t
>
valid_range
=
{
strides_indices_vec
[
d
]
>
0
?
0
:
-
1
,
strides_indices_vec
[
d
]
>
0
?
dim_len
:
dim_len
-
1
};
auto
format_indices
=
[
valid_range
,
d
,
dim_len
](
index_t
indice
)
{
index_t
forward
=
indice
<
0
?
indice
+
dim_len
:
indice
;
return
Clamp
(
forward
,
valid_range
[
0
],
valid_range
[
1
]);
};
if
(
!
(
shrink_axis_mask_
&
(
1
<<
d
)))
{
if
(
begin_mask_
&
(
1
<<
d
))
{
begin_indices_vec
[
d
]
=
strides_indices_vec
[
d
]
>
0
?
0
:
dim_len
-
1
;
}
else
{
begin_indices_vec
[
d
]
=
format_indices
(
begin_indices_vec
[
d
]);
}
if
(
end_mask_
&
(
1
<<
d
))
{
end_indices_vec
[
d
]
=
strides_indices_vec
[
d
]
>
0
?
dim_len
:
-
1
;
}
else
{
end_indices_vec
[
d
]
=
format_indices
(
end_indices_vec
[
d
]);
}
int32_t
out_dim_len
=
std
::
max
(
0.
f
,
std
::
ceil
((
end_indices_vec
[
d
]
-
begin_indices_vec
[
d
])
/
static_cast
<
float
>
(
strides_indices_vec
[
d
])));
output_shape
.
push_back
(
out_dim_len
);
}
else
{
begin_indices_vec
[
d
]
=
begin_indices_vec
[
d
]
<
0
?
begin_indices_vec
[
d
]
+
dim_len
:
begin_indices_vec
[
d
];
end_indices_vec
[
d
]
=
begin_indices_vec
[
d
]
+
1
;
MACE_CHECK
(
begin_indices_vec
[
d
]
>=
0
&&
begin_indices_vec
[
d
]
<
dim_len
,
"slice begin indice of dimension '"
,
d
,
"': "
,
begin_indices_vec
[
d
],
", is out of bound"
);
}
}
}
for
(
size_t
i
=
0
;
i
<
output_shape
.
size
();
++
i
)
{
MACE_CHECK
(
output_shape
[
i
]
>
0
,
"Expected output_shape["
,
i
,
"] larger than 0, but got "
,
output_shape
[
i
]);
}
std
::
vector
<
index_t
>
dim_stride
(
input
->
dim_size
(),
1
);
for
(
index_t
d
=
input
->
dim_size
()
-
2
;
d
>=
0
;
--
d
)
{
dim_stride
[
d
]
=
dim_stride
[
d
+
1
]
*
input
->
dim
(
d
+
1
);
}
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
MACE_RETURN_IF_ERROR
(
output
->
Resize
(
output_shape
));
Tensor
::
MappingGuard
output_guard
(
output
);
T
*
output_data
=
output
->
mutable_data
<
T
>
();
bool
slice_by_first_axis
=
true
;
if
(
strides_
data
[
0
]
!=
1
)
{
if
(
strides_
indices_vec
[
0
]
!=
1
)
{
slice_by_first_axis
=
false
;
}
else
{
for
(
index_t
d
=
1
;
d
<
input
->
dim_size
();
++
d
)
{
if
(
strides_
data
[
d
]
!=
1
||
real_begin_indices
[
d
]
!=
0
||
real_end_indices
[
d
]
!=
input
->
dim
(
d
))
{
if
(
strides_
indices_vec
[
d
]
!=
1
||
begin_indices_vec
[
d
]
!=
0
||
end_indices_vec
[
d
]
!=
input
->
dim
(
d
))
{
slice_by_first_axis
=
false
;
break
;
}
...
...
@@ -210,64 +263,64 @@ class StridedSliceOp : public Operation {
}
if
(
slice_by_first_axis
)
{
memcpy
(
output_data
,
input_data
+
real_begin_indices
[
0
]
*
dim_stride
[
0
],
sizeof
(
T
)
*
(
real_end_indices
[
0
]
-
real_begin_indices
[
0
])
*
memcpy
(
output_data
,
input_data
+
begin_indices_vec
[
0
]
*
dim_stride
[
0
],
sizeof
(
T
)
*
(
end_indices_vec
[
0
]
-
begin_indices_vec
[
0
])
*
dim_stride
[
0
]);
}
else
{
if
(
input
->
dim_size
()
==
1
)
{
for
(
index_t
i
=
real_begin_indices
[
0
];
strides_
data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
:
i
>
real_end_indices
[
0
];
i
+=
strides_
data
[
0
])
{
for
(
index_t
i
=
begin_indices_vec
[
0
];
strides_
indices_vec
[
0
]
>
0
?
i
<
end_indices_vec
[
0
]
:
i
>
end_indices_vec
[
0
];
i
+=
strides_
indices_vec
[
0
])
{
*
output_data
++
=
input_data
[
i
];
}
}
else
if
(
input
->
dim_size
()
==
2
)
{
for
(
index_t
i
=
real_begin_indices
[
0
];
strides_
data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
:
i
>
real_end_indices
[
0
];
i
+=
strides_
data
[
0
])
{
for
(
index_t
j
=
real_begin_indices
[
1
];
strides_
data
[
1
]
>
0
?
j
<
real_end_indices
[
1
]
:
j
>
real_end_indices
[
1
];
j
+=
strides_
data
[
1
])
{
for
(
index_t
i
=
begin_indices_vec
[
0
];
strides_
indices_vec
[
0
]
>
0
?
i
<
end_indices_vec
[
0
]
:
i
>
end_indices_vec
[
0
];
i
+=
strides_
indices_vec
[
0
])
{
for
(
index_t
j
=
begin_indices_vec
[
1
];
strides_
indices_vec
[
1
]
>
0
?
j
<
end_indices_vec
[
1
]
:
j
>
end_indices_vec
[
1
];
j
+=
strides_
indices_vec
[
1
])
{
*
output_data
++
=
input_data
[
i
*
input
->
dim
(
1
)
+
j
];
}
}
}
else
if
(
input
->
dim_size
()
==
3
)
{
for
(
index_t
i
=
real_begin_indices
[
0
];
strides_
data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
:
i
>
real_end_indices
[
0
];
i
+=
strides_
data
[
0
])
{
for
(
index_t
j
=
real_begin_indices
[
1
];
strides_
data
[
1
]
>
0
?
j
<
real_end_indices
[
1
]
:
j
>
real_end_indices
[
1
];
j
+=
strides_
data
[
1
])
{
for
(
index_t
k
=
real_begin_indices
[
2
];
strides_
data
[
2
]
>
0
?
k
<
real_end_indices
[
2
]
:
k
>
real_end_indices
[
2
];
k
+=
strides_
data
[
2
])
{
for
(
index_t
i
=
begin_indices_vec
[
0
];
strides_
indices_vec
[
0
]
>
0
?
i
<
end_indices_vec
[
0
]
:
i
>
end_indices_vec
[
0
];
i
+=
strides_
indices_vec
[
0
])
{
for
(
index_t
j
=
begin_indices_vec
[
1
];
strides_
indices_vec
[
1
]
>
0
?
j
<
end_indices_vec
[
1
]
:
j
>
end_indices_vec
[
1
];
j
+=
strides_
indices_vec
[
1
])
{
for
(
index_t
k
=
begin_indices_vec
[
2
];
strides_
indices_vec
[
2
]
>
0
?
k
<
end_indices_vec
[
2
]
:
k
>
end_indices_vec
[
2
];
k
+=
strides_
indices_vec
[
2
])
{
*
output_data
++
=
input_data
[(
i
*
input
->
dim
(
1
)
+
j
)
*
input
->
dim
(
2
)
+
k
];
}
}
}
}
else
if
(
input
->
dim_size
()
==
4
)
{
for
(
index_t
i
=
real_begin_indices
[
0
];
strides_
data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
:
i
>
real_end_indices
[
0
];
i
+=
strides_
data
[
0
])
{
for
(
index_t
j
=
real_begin_indices
[
1
];
strides_
data
[
1
]
>
0
?
j
<
real_end_indices
[
1
]
:
j
>
real_end_indices
[
1
];
j
+=
strides_
data
[
1
])
{
for
(
index_t
k
=
real_begin_indices
[
2
];
strides_
data
[
2
]
>
0
?
k
<
real_end_indices
[
2
]
:
k
>
real_end_indices
[
2
];
k
+=
strides_
data
[
2
])
{
for
(
index_t
l
=
real_begin_indices
[
3
];
strides_
data
[
3
]
>
0
?
l
<
real_end_indices
[
3
]
:
l
>
real_end_indices
[
3
];
l
+=
strides_
data
[
3
])
{
for
(
index_t
i
=
begin_indices_vec
[
0
];
strides_
indices_vec
[
0
]
>
0
?
i
<
end_indices_vec
[
0
]
:
i
>
end_indices_vec
[
0
];
i
+=
strides_
indices_vec
[
0
])
{
for
(
index_t
j
=
begin_indices_vec
[
1
];
strides_
indices_vec
[
1
]
>
0
?
j
<
end_indices_vec
[
1
]
:
j
>
end_indices_vec
[
1
];
j
+=
strides_
indices_vec
[
1
])
{
for
(
index_t
k
=
begin_indices_vec
[
2
];
strides_
indices_vec
[
2
]
>
0
?
k
<
end_indices_vec
[
2
]
:
k
>
end_indices_vec
[
2
];
k
+=
strides_
indices_vec
[
2
])
{
for
(
index_t
l
=
begin_indices_vec
[
3
];
strides_
indices_vec
[
3
]
>
0
?
l
<
end_indices_vec
[
3
]
:
l
>
end_indices_vec
[
3
];
l
+=
strides_
indices_vec
[
3
])
{
*
output_data
++
=
input_data
[((
i
*
input
->
dim
(
1
)
+
j
)
*
input
->
dim
(
2
)
+
k
)
*
input
->
dim
(
3
)
+
l
];
...
...
@@ -289,6 +342,7 @@ class StridedSliceOp : public Operation {
int
new_axis_mask_
;
int
shrink_axis_mask_
;
bool
is_slice_
;
int
has_data_format_
;
bool
checked_
;
Tensor
tmp_strides_tensor_
;
...
...
mace/ops/strided_slice_test.cc
浏览文件 @
a14a6cb4
...
...
@@ -64,6 +64,54 @@ void TestStridedSlice(const std::vector<index_t> &input_shape,
*
net
.
GetOutput
(
"Output"
));
}
void
TestStridedSliceWithDataFormat
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
float
>
&
input
,
const
std
::
vector
<
int32_t
>
&
begin_indices
,
const
std
::
vector
<
int32_t
>
&
end_indices
,
const
std
::
vector
<
int32_t
>
&
strides
,
const
int
begin_mask
,
const
int
end_mask
,
const
int
ellipsis_mask
,
const
int
new_axis_mask
,
const
int
shrink_axis_mask
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
OpsTestNet
net
;
net
.
AddInputFromArray
<
CPU
,
float
>
(
"Input"
,
input_shape
,
input
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
static_cast
<
int32_t
>
(
begin_indices
.
size
())},
begin_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"EndIndices"
,
{
static_cast
<
int32_t
>
(
end_indices
.
size
())},
end_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Strides"
,
{
static_cast
<
int32_t
>
(
strides
.
size
())},
strides
);
net
.
TransformDataFormat
<
DeviceType
::
CPU
,
float
>
(
"Input"
,
NHWC
,
"InputNCHW"
,
NCHW
);
OpDefBuilder
(
"StridedSlice"
,
"StridedSliceOpTest"
)
.
Input
(
"InputNCHW"
)
.
Input
(
"BeginIndices"
)
.
Input
(
"EndIndices"
)
.
Input
(
"Strides"
)
.
Output
(
"OutputNCHW"
)
.
AddIntArg
(
"begin_mask"
,
begin_mask
)
.
AddIntArg
(
"end_mask"
,
end_mask
)
.
AddIntArg
(
"ellipsis_mask"
,
ellipsis_mask
)
.
AddIntArg
(
"new_axis_mask"
,
new_axis_mask
)
.
AddIntArg
(
"shrink_axis_mask"
,
shrink_axis_mask
)
.
AddIntArg
(
"has_data_format"
,
1
)
.
Finalize
(
net
.
NewOperatorDef
());
net
.
RunOp
();
net
.
TransformDataFormat
<
DeviceType
::
CPU
,
float
>
(
"OutputNCHW"
,
NCHW
,
"Output"
,
NHWC
);
net
.
AddInputFromArray
<
CPU
,
float
>
(
"ExpectedOutput"
,
output_shape
,
output
);
ExpectTensorNear
<
float
>
(
*
net
.
GetOutput
(
"ExpectedOutput"
),
*
net
.
GetOutput
(
"Output"
));
}
void
TestSlice
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
float
>
&
input
,
const
std
::
vector
<
int32_t
>
&
begin_indices
,
...
...
@@ -92,6 +140,41 @@ void TestSlice(const std::vector<index_t> &input_shape,
*
net
.
GetOutput
(
"Output"
));
}
void
TestSliceWithDataFormat
(
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
float
>
&
input
,
const
std
::
vector
<
int32_t
>
&
begin_indices
,
const
std
::
vector
<
int32_t
>
&
indices_size
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
OpsTestNet
net
;
net
.
AddInputFromArray
<
CPU
,
float
>
(
"Input"
,
input_shape
,
input
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
begin_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"IndicesSize"
,
{
static_cast
<
int32_t
>
(
indices_size
.
size
())},
indices_size
);
net
.
TransformDataFormat
<
DeviceType
::
CPU
,
float
>
(
"Input"
,
NHWC
,
"InputNCHW"
,
NCHW
);
OpDefBuilder
(
"StridedSlice"
,
"StridedSliceOpTest"
)
.
Input
(
"InputNCHW"
)
.
Input
(
"BeginIndices"
)
.
Input
(
"IndicesSize"
)
.
Output
(
"OutputNCHW"
)
.
AddIntArg
(
"slice"
,
1
)
.
AddIntArg
(
"has_data_format"
,
1
)
.
Finalize
(
net
.
NewOperatorDef
());
net
.
RunOp
();
net
.
TransformDataFormat
<
DeviceType
::
CPU
,
float
>
(
"OutputNCHW"
,
NCHW
,
"Output"
,
NHWC
);
net
.
AddInputFromArray
<
CPU
,
float
>
(
"ExpectedOutput"
,
output_shape
,
output
);
ExpectTensorNear
<
float
>
(
*
net
.
GetOutput
(
"ExpectedOutput"
),
*
net
.
GetOutput
(
"Output"
));
}
}
// namespace
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceByFirstAxis
)
{
...
...
@@ -157,6 +240,66 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank3) {
1
,
2
},
{
1
,
1
,
3
,
3
});
}
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceRank4
)
{
TestStridedSlice
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
2
,
1
,
2
},
{
15
,
16
,
21
,
22
});
TestStridedSlice
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
3
,
0
,
0
,
0
,
0
,
{
2
,
2
,
1
,
2
},
{
3
,
4
,
9
,
10
,
15
,
16
,
21
,
22
});
TestStridedSlice
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
0
,
8
,
0
,
0
,
0
,
{
1
,
2
,
1
,
3
},
{
15
,
16
,
17
,
21
,
22
,
23
});
TestStridedSlice
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
0
,
8
,
0
,
0
,
8
,
{
1
,
2
,
1
},
{
15
,
21
});
TestStridedSlice
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
0
,
8
,
0
,
0
,
15
,
{},
{
15
});
TestStridedSlice
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
-
1
,
2
,
1
,
3
},
{
0
,
0
,
0
,
0
},
{
-
1
,
-
1
,
-
1
,
-
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
1
,
1
,
2
},
{
23
,
22
});
}
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceWithDataFormat
)
{
TestStridedSliceWithDataFormat
(
{
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
2
,
1
,
2
},
{
15
,
16
,
21
,
22
});
TestStridedSliceWithDataFormat
(
{
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
3
,
0
,
0
,
0
,
0
,
{
2
,
2
,
1
,
2
},
{
3
,
4
,
9
,
10
,
15
,
16
,
21
,
22
});
TestStridedSliceWithDataFormat
(
{
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
2
,
2
,
2
,
2
},
{
1
,
1
,
1
,
1
},
0
,
8
,
0
,
0
,
0
,
{
1
,
2
,
1
,
3
},
{
15
,
16
,
17
,
21
,
22
,
23
});
TestStridedSliceWithDataFormat
(
{
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
},
{
2
,
1
},
{
1
,
1
},
0
,
8
,
0
,
0
,
0
,
{
1
,
1
,
2
,
3
},
{
12
,
13
,
14
,
15
,
16
,
17
});
TestStridedSliceWithDataFormat
(
{
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
},
{
2
,
1
},
{
1
,
1
},
0
,
2
,
0
,
0
,
0
,
{
1
,
2
,
2
,
3
},
{
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
});
TestStridedSliceWithDataFormat
(
{
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
-
1
,
2
,
1
,
3
},
{
0
,
0
,
0
,
0
},
{
-
1
,
-
1
,
-
1
,
-
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
1
,
1
,
2
},
{
23
,
22
});
}
TEST_F
(
StridedSliceOpTest
,
TestSlice
)
{
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
});
...
...
@@ -166,6 +309,17 @@ TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
1
},
{
2
,
-
1
},
{
2
,
2
},
{
2
,
3
,
5
,
6
});
}
TEST_F
(
StridedSliceOpTest
,
TestSliceWithDataFormat
)
{
TestSliceWithDataFormat
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
1
,
2
,
1
,
2
},
{
1
,
2
,
1
,
2
},
{
15
,
16
,
21
,
22
});
TestSliceWithDataFormat
({
2
,
2
,
2
,
3
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
},
{
1
,
0
,
1
,
0
},
{
-
1
,
-
1
,
-
1
,
-
1
},
{
1
,
2
,
1
,
3
},
{
15
,
16
,
17
,
21
,
22
,
23
});
}
}
// namespace test
}
// namespace ops
}
// namespace mace
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录