Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
92f18fc6
Mace
项目概览
Xiaomi
/
Mace
通知
107
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
92f18fc6
编写于
8月 16, 2018
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support tf basic lstm on cpu
上级
b77d694f
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
328 addition
and
69 deletion
+328
-69
mace/kernels/fill.h
mace/kernels/fill.h
+72
-0
mace/kernels/opencl/cl/split.cl
mace/kernels/opencl/cl/split.cl
+1
-1
mace/kernels/opencl/split.cc
mace/kernels/opencl/split.cc
+8
-8
mace/kernels/split.h
mace/kernels/split.h
+9
-9
mace/kernels/strided_slice.h
mace/kernels/strided_slice.h
+18
-2
mace/ops/fill.cc
mace/ops/fill.cc
+29
-0
mace/ops/fill.h
mace/ops/fill.h
+49
-0
mace/ops/fill_test.cc
mace/ops/fill_test.cc
+65
-0
mace/ops/ops_register.cc
mace/ops/ops_register.cc
+4
-2
mace/ops/split.cc
mace/ops/split.cc
+8
-8
mace/ops/split.h
mace/ops/split.h
+9
-9
mace/ops/split_benchmark.cc
mace/ops/split_benchmark.cc
+16
-16
mace/ops/split_test.cc
mace/ops/split_test.cc
+8
-8
mace/ops/strided_slice_test.cc
mace/ops/strided_slice_test.cc
+12
-0
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+3
-0
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+15
-5
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+1
-0
repository/opencl-kernel/opencl_kernel_configure.bzl
repository/opencl-kernel/opencl_kernel_configure.bzl
+1
-1
未找到文件。
mace/kernels/fill.h
0 → 100644
浏览文件 @
92f18fc6
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_FILL_H_
#define MACE_KERNELS_FILL_H_
#include <algorithm>
#include <functional>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace
mace
{
namespace
kernels
{
// Base struct holding the constant fill value shared by all FillFunctor
// specializations.
struct FillBase {
  explicit FillBase(float value) : value_(value) {}

  // Fix: declared as float. The original member was `int value_;`, which
  // silently truncated the float constructor argument (e.g. a fill value of
  // 0.5f became 0).
  float value_;
};
template <DeviceType D, class T>
struct FillFunctor;

// CPU/float implementation: reads a 1-D int32 shape tensor and produces an
// output tensor of that shape with every element set to value_ (mirrors
// tf.fill semantics).
template <>
struct FillFunctor<DeviceType::CPU, float> : FillBase {
  explicit FillFunctor(float value) : FillBase(value) {}

  // shape:  1-D int32 tensor holding the output dimensions.
  // output: resized to the requested shape and filled with value_.
  // future: unused on CPU.
  // Returns MACE_SUCCESS, or propagates the error from Resize().
  MaceStatus operator()(const Tensor *shape,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(shape->dim_size() == 1) << "Shape must be 1-D";
    const index_t num_dims = shape->dim(0);
    Tensor::MappingGuard shape_guard(shape);
    const int32_t *shape_data = shape->data<int32_t>();

    std::vector<index_t> output_shape;
    for (index_t i = 0; i < num_dims; ++i) {
      // Fix: the check now matches its message. The original used
      // `shape_data[i] > 0`, rejecting zero-sized dimensions even though the
      // message says "non-negative"; zero dims are legal (empty tensor) and
      // std::fill over zero elements below is a no-op.
      MACE_CHECK(shape_data[i] >= 0)
          << "Shape must be non-negative: " << shape_data[i];
      output_shape.push_back(shape_data[i]);
    }
    MACE_RETURN_IF_ERROR(output->Resize(output_shape));

    Tensor::MappingGuard output_guard(output);
    float *output_data = output->mutable_data<float>();
    std::fill(output_data, output_data + output->size(), value_);
    return MACE_SUCCESS;
  }
};
}
// namespace kernels
}
// namespace mace
#endif // MACE_KERNELS_FILL_H_
mace/kernels/opencl/cl/s
lice
.cl
→
mace/kernels/opencl/cl/s
plit
.cl
浏览文件 @
92f18fc6
#
include
<common.h>
#
include
<common.h>
__kernel
void
s
lice
(
KERNEL_ERROR_PARAMS
__kernel
void
s
plit
(
KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only
image2d_t
input,
__read_only
image2d_t
input,
__private
const
int
chan_blk_offset,
__private
const
int
chan_blk_offset,
...
...
mace/kernels/opencl/s
lice
.cc
→
mace/kernels/opencl/s
plit
.cc
浏览文件 @
92f18fc6
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "mace/kernels/s
lice
.h"
#include "mace/kernels/s
plit
.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/tuner.h"
...
@@ -21,7 +21,7 @@ namespace mace {
...
@@ -21,7 +21,7 @@ namespace mace {
namespace
kernels
{
namespace
kernels
{
template
<
typename
T
>
template
<
typename
T
>
MaceStatus
S
lice
Functor
<
DeviceType
::
GPU
,
T
>::
operator
()(
MaceStatus
S
plit
Functor
<
DeviceType
::
GPU
,
T
>::
operator
()(
const
Tensor
*
input
,
const
Tensor
*
input
,
const
std
::
vector
<
Tensor
*>
&
output_list
,
const
std
::
vector
<
Tensor
*>
&
output_list
,
StatsFuture
*
future
)
{
StatsFuture
*
future
)
{
...
@@ -29,7 +29,7 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
...
@@ -29,7 +29,7 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
const
size_t
outputs_count
=
output_list
.
size
();
const
size_t
outputs_count
=
output_list
.
size
();
const
index_t
output_channels
=
input_channels
/
outputs_count
;
const
index_t
output_channels
=
input_channels
/
outputs_count
;
MACE_CHECK
(
output_channels
%
4
==
0
)
MACE_CHECK
(
output_channels
%
4
==
0
)
<<
"output channels of s
lice
op must be divisible by 4"
;
<<
"output channels of s
plit
op must be divisible by 4"
;
std
::
vector
<
index_t
>
output_shape
(
std
::
vector
<
index_t
>
output_shape
(
{
input
->
dim
(
0
),
input
->
dim
(
1
),
input
->
dim
(
2
),
output_channels
});
{
input
->
dim
(
0
),
input
->
dim
(
1
),
input
->
dim
(
2
),
output_channels
});
...
@@ -46,12 +46,12 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
...
@@ -46,12 +46,12 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
std
::
set
<
std
::
string
>
built_options
;
std
::
set
<
std
::
string
>
built_options
;
OUT_OF_RANGE_CONFIG
(
kernel_error_
);
OUT_OF_RANGE_CONFIG
(
kernel_error_
);
NON_UNIFORM_WG_CONFIG
;
NON_UNIFORM_WG_CONFIG
;
std
::
string
kernel_name
=
MACE_OBFUSCATE_SYMBOL
(
"s
lice
"
);
std
::
string
kernel_name
=
MACE_OBFUSCATE_SYMBOL
(
"s
plit
"
);
built_options
.
emplace
(
"-Ds
lice
="
+
kernel_name
);
built_options
.
emplace
(
"-Ds
plit
="
+
kernel_name
);
built_options
.
emplace
(
"-DDATA_TYPE="
+
DtToCLDt
(
DataTypeToEnum
<
T
>::
value
));
built_options
.
emplace
(
"-DDATA_TYPE="
+
DtToCLDt
(
DataTypeToEnum
<
T
>::
value
));
built_options
.
emplace
(
"-DCMD_DATA_TYPE="
+
built_options
.
emplace
(
"-DCMD_DATA_TYPE="
+
DtToCLCMDDt
(
DataTypeToEnum
<
T
>::
value
));
DtToCLCMDDt
(
DataTypeToEnum
<
T
>::
value
));
MACE_RETURN_IF_ERROR
(
runtime
->
BuildKernel
(
"s
lice
"
,
MACE_RETURN_IF_ERROR
(
runtime
->
BuildKernel
(
"s
plit
"
,
kernel_name
,
kernel_name
,
built_options
,
built_options
,
&
kernel_
));
&
kernel_
));
...
@@ -116,8 +116,8 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
...
@@ -116,8 +116,8 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
return
MACE_SUCCESS
;
return
MACE_SUCCESS
;
}
}
template
struct
S
lice
Functor
<
DeviceType
::
GPU
,
float
>;
template
struct
S
plit
Functor
<
DeviceType
::
GPU
,
float
>;
template
struct
S
lice
Functor
<
DeviceType
::
GPU
,
half
>;
template
struct
S
plit
Functor
<
DeviceType
::
GPU
,
half
>;
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
mace/kernels/s
lice
.h
→
mace/kernels/s
plit
.h
浏览文件 @
92f18fc6
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#ifndef MACE_KERNELS_S
LICE
_H_
#ifndef MACE_KERNELS_S
PLIT
_H_
#define MACE_KERNELS_S
LICE
_H_
#define MACE_KERNELS_S
PLIT
_H_
#include <memory>
#include <memory>
#include <functional>
#include <functional>
...
@@ -31,15 +31,15 @@
...
@@ -31,15 +31,15 @@
namespace
mace
{
namespace
mace
{
namespace
kernels
{
namespace
kernels
{
struct
S
lice
FunctorBase
{
struct
S
plit
FunctorBase
{
explicit
S
lice
FunctorBase
(
const
int32_t
axis
)
:
axis_
(
axis
)
{}
explicit
S
plit
FunctorBase
(
const
int32_t
axis
)
:
axis_
(
axis
)
{}
int32_t
axis_
;
int32_t
axis_
;
};
};
template
<
DeviceType
D
,
typename
T
>
template
<
DeviceType
D
,
typename
T
>
struct
S
liceFunctor
:
Slice
FunctorBase
{
struct
S
plitFunctor
:
Split
FunctorBase
{
explicit
S
liceFunctor
(
const
int32_t
axis
)
:
Slice
FunctorBase
(
axis
)
{}
explicit
S
plitFunctor
(
const
int32_t
axis
)
:
Split
FunctorBase
(
axis
)
{}
MaceStatus
operator
()(
const
Tensor
*
input
,
MaceStatus
operator
()(
const
Tensor
*
input
,
const
std
::
vector
<
Tensor
*>
&
output_list
,
const
std
::
vector
<
Tensor
*>
&
output_list
,
...
@@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase {
...
@@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase {
#ifdef MACE_ENABLE_OPENCL
#ifdef MACE_ENABLE_OPENCL
template
<
typename
T
>
template
<
typename
T
>
struct
S
liceFunctor
<
DeviceType
::
GPU
,
T
>
:
Slice
FunctorBase
{
struct
S
plitFunctor
<
DeviceType
::
GPU
,
T
>
:
Split
FunctorBase
{
explicit
S
liceFunctor
(
const
int32_t
axis
)
:
Slice
FunctorBase
(
axis
)
{}
explicit
S
plitFunctor
(
const
int32_t
axis
)
:
Split
FunctorBase
(
axis
)
{}
MaceStatus
operator
()(
const
Tensor
*
input
,
MaceStatus
operator
()(
const
Tensor
*
input
,
const
std
::
vector
<
Tensor
*>
&
output_list
,
const
std
::
vector
<
Tensor
*>
&
output_list
,
...
@@ -104,4 +104,4 @@ struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
...
@@ -104,4 +104,4 @@ struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
#endif // MACE_KERNELS_S
LICE
_H_
#endif // MACE_KERNELS_S
PLIT
_H_
mace/kernels/strided_slice.h
浏览文件 @
92f18fc6
...
@@ -169,7 +169,6 @@ struct StridedSliceFunctor {
...
@@ -169,7 +169,6 @@ struct StridedSliceFunctor {
i
+=
strides_data
[
0
])
{
i
+=
strides_data
[
0
])
{
*
output_data
++
=
input_data
[
i
];
*
output_data
++
=
input_data
[
i
];
}
}
}
else
if
(
input
->
dim_size
()
==
2
)
{
}
else
if
(
input
->
dim_size
()
==
2
)
{
for
(
index_t
i
=
real_begin_indices
[
0
];
for
(
index_t
i
=
real_begin_indices
[
0
];
strides_data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
strides_data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
...
@@ -179,7 +178,24 @@ struct StridedSliceFunctor {
...
@@ -179,7 +178,24 @@ struct StridedSliceFunctor {
strides_data
[
1
]
>
0
?
j
<
real_end_indices
[
1
]
strides_data
[
1
]
>
0
?
j
<
real_end_indices
[
1
]
:
j
>
real_end_indices
[
1
];
:
j
>
real_end_indices
[
1
];
j
+=
strides_data
[
1
])
{
j
+=
strides_data
[
1
])
{
*
output_data
++
=
input_data
[
i
*
dim_stride
[
0
]
+
j
];
*
output_data
++
=
input_data
[
i
*
input
->
dim
(
1
)
+
j
];
}
}
}
else
if
(
input
->
dim_size
()
==
3
)
{
for
(
index_t
i
=
real_begin_indices
[
0
];
strides_data
[
0
]
>
0
?
i
<
real_end_indices
[
0
]
:
i
>
real_end_indices
[
0
];
i
+=
strides_data
[
0
])
{
for
(
index_t
j
=
real_begin_indices
[
1
];
strides_data
[
1
]
>
0
?
j
<
real_end_indices
[
1
]
:
j
>
real_end_indices
[
1
];
j
+=
strides_data
[
1
])
{
for
(
index_t
k
=
real_begin_indices
[
2
];
strides_data
[
2
]
>
0
?
k
<
real_end_indices
[
2
]
:
k
>
real_end_indices
[
2
];
k
+=
strides_data
[
2
])
{
*
output_data
++
=
input_data
[(
i
*
input
->
dim
(
1
)
+
j
)
*
input
->
dim
(
2
)
+
k
];
}
}
}
}
}
}
else
{
}
else
{
...
...
mace/ops/fill.cc
0 → 100644
浏览文件 @
92f18fc6
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/fill.h"
namespace
mace
{
namespace
ops
{
// Registers the Fill operator with the op registry.
// Fill currently runs only on CPU with float tensors.
void Register_Fill(OperatorRegistryBase *op_registry) {
  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Fill")
                                          .Device(DeviceType::CPU)
                                          .TypeConstraint<float>("T")
                                          .Build(),
                         FillOp<DeviceType::CPU, float>);
}
}
// namespace ops
}
// namespace mace
mace/ops/fill.h
0 → 100644
浏览文件 @
92f18fc6
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_FILL_H_
#define MACE_OPS_FILL_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/fill.h"
namespace
mace
{
namespace
ops
{
// Fill operator: produces an output tensor whose shape is given by the
// 1-D SHAPE input tensor and whose every element equals the "value"
// op argument (default 0.0f). Delegates the work to kernels::FillFunctor.
template <DeviceType D, class T>
class FillOp : public Operator<D, T> {
 public:
  FillOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        // The "value" argument selects the fill constant.
        functor_(OperatorBase::GetOptionalArg<float>("value", 0.0f)) {}

  MaceStatus Run(StatsFuture *future) override {
    const Tensor *shape = this->Input(SHAPE);  // 1-D shape tensor
    Tensor *output = this->Output(OUTPUT);
    return functor_(shape, output, future);
  }

 private:
  kernels::FillFunctor<D, T> functor_;

  MACE_OP_INPUT_TAGS(SHAPE);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
}
// namespace ops
}
// namespace mace
#endif // MACE_OPS_FILL_H_
mace/ops/fill_test.cc
0 → 100644
浏览文件 @
92f18fc6
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace
mace
{
namespace
ops
{
namespace
test
{
class
FillTest
:
public
OpsTestBase
{};
namespace
{
// Builds a Fill op fed with the given shape vector (as a 1-D int32 tensor)
// and fill value, runs it on CPU, then verifies that the output has the
// requested dimensions and that every element equals the fill value.
void TestFill(const std::vector<int32_t> &shape, const float &value) {
  // Construct graph
  OpsTestNet net;
  OpDefBuilder("Fill", "FillTest")
      .Input("Shape")
      .AddFloatArg("value", static_cast<float>(value))
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // Add input data
  net.AddInputFromArray<DeviceType::CPU, int32_t>(
      "Shape", {static_cast<index_t>(shape.size())}, shape);

  // Run
  net.RunOp();

  auto output = net.GetTensor("Output");
  // Output rank and per-axis dims must match the requested shape.
  for (index_t i = 0; i < output->dim_size(); ++i) {
    ASSERT_EQ(output->dim(i), shape[i]);
  }

  // Every element of the output must equal the fill value.
  const float *output_ptr = output->data<float>();
  const index_t size = output->size();
  for (index_t i = 0; i < size; ++i) {
    ASSERT_EQ(output_ptr[i], value);
  }
}
}
// namespace
// Covers a rank-3 fill with a positive value and a rank-2 fill with a
// negative value.
TEST_F(FillTest, Simple) {
  TestFill({3, 2, 1}, 5.0f);
  TestFill({1, 3}, -1.0f);
}
}
// namespace test
}
// namespace ops
}
// namespace mace
mace/ops/ops_register.cc
浏览文件 @
92f18fc6
...
@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
...
@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern
void
Register_DepthwiseConv2d
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_DepthwiseConv2d
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Dequantize
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Dequantize
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Eltwise
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Eltwise
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Fill
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_FoldedBatchNorm
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_FoldedBatchNorm
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_FullyConnected
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_FullyConnected
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Gather
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Gather
(
OperatorRegistryBase
*
op_registry
);
...
@@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
...
@@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern
void
Register_Reshape
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Reshape
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_ResizeBilinear
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_ResizeBilinear
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Shape
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Shape
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_S
lice
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_S
plit
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Softmax
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Softmax
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Stack
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_Stack
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_StridedSlice
(
OperatorRegistryBase
*
op_registry
);
extern
void
Register_StridedSlice
(
OperatorRegistryBase
*
op_registry
);
...
@@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
...
@@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops
::
Register_DepthwiseConv2d
(
this
);
ops
::
Register_DepthwiseConv2d
(
this
);
ops
::
Register_Dequantize
(
this
);
ops
::
Register_Dequantize
(
this
);
ops
::
Register_Eltwise
(
this
);
ops
::
Register_Eltwise
(
this
);
ops
::
Register_Fill
(
this
);
ops
::
Register_FoldedBatchNorm
(
this
);
ops
::
Register_FoldedBatchNorm
(
this
);
ops
::
Register_FullyConnected
(
this
);
ops
::
Register_FullyConnected
(
this
);
ops
::
Register_Gather
(
this
);
ops
::
Register_Gather
(
this
);
...
@@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
...
@@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops
::
Register_Reshape
(
this
);
ops
::
Register_Reshape
(
this
);
ops
::
Register_ResizeBilinear
(
this
);
ops
::
Register_ResizeBilinear
(
this
);
ops
::
Register_Shape
(
this
);
ops
::
Register_Shape
(
this
);
ops
::
Register_S
lice
(
this
);
ops
::
Register_S
plit
(
this
);
ops
::
Register_Softmax
(
this
);
ops
::
Register_Softmax
(
this
);
ops
::
Register_Stack
(
this
);
ops
::
Register_Stack
(
this
);
ops
::
Register_StridedSlice
(
this
);
ops
::
Register_StridedSlice
(
this
);
...
...
mace/ops/s
lice
.cc
→
mace/ops/s
plit
.cc
浏览文件 @
92f18fc6
...
@@ -12,30 +12,30 @@
...
@@ -12,30 +12,30 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "mace/ops/s
lice
.h"
#include "mace/ops/s
plit
.h"
namespace
mace
{
namespace
mace
{
namespace
ops
{
namespace
ops
{
void
Register_S
lice
(
OperatorRegistryBase
*
op_registry
)
{
void
Register_S
plit
(
OperatorRegistryBase
*
op_registry
)
{
MACE_REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"S
lice
"
)
MACE_REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"S
plit
"
)
.
Device
(
DeviceType
::
CPU
)
.
Device
(
DeviceType
::
CPU
)
.
TypeConstraint
<
float
>
(
"T"
)
.
TypeConstraint
<
float
>
(
"T"
)
.
Build
(),
.
Build
(),
S
lice
Op
<
DeviceType
::
CPU
,
float
>
);
S
plit
Op
<
DeviceType
::
CPU
,
float
>
);
#ifdef MACE_ENABLE_OPENCL
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"S
lice
"
)
MACE_REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"S
plit
"
)
.
Device
(
DeviceType
::
GPU
)
.
Device
(
DeviceType
::
GPU
)
.
TypeConstraint
<
float
>
(
"T"
)
.
TypeConstraint
<
float
>
(
"T"
)
.
Build
(),
.
Build
(),
S
lice
Op
<
DeviceType
::
GPU
,
float
>
);
S
plit
Op
<
DeviceType
::
GPU
,
float
>
);
MACE_REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"S
lice
"
)
MACE_REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"S
plit
"
)
.
Device
(
DeviceType
::
GPU
)
.
Device
(
DeviceType
::
GPU
)
.
TypeConstraint
<
half
>
(
"T"
)
.
TypeConstraint
<
half
>
(
"T"
)
.
Build
(),
.
Build
(),
S
lice
Op
<
DeviceType
::
GPU
,
half
>
);
S
plit
Op
<
DeviceType
::
GPU
,
half
>
);
#endif // MACE_ENABLE_OPENCL
#endif // MACE_ENABLE_OPENCL
}
}
...
...
mace/ops/s
lice
.h
→
mace/ops/s
plit
.h
浏览文件 @
92f18fc6
...
@@ -12,21 +12,21 @@
...
@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#ifndef MACE_OPS_S
LICE
_H_
#ifndef MACE_OPS_S
PLIT
_H_
#define MACE_OPS_S
LICE
_H_
#define MACE_OPS_S
PLIT
_H_
#include <vector>
#include <vector>
#include "mace/core/operator.h"
#include "mace/core/operator.h"
#include "mace/kernels/s
lice
.h"
#include "mace/kernels/s
plit
.h"
namespace
mace
{
namespace
mace
{
namespace
ops
{
namespace
ops
{
template
<
DeviceType
D
,
typename
T
>
template
<
DeviceType
D
,
typename
T
>
class
S
lice
Op
:
public
Operator
<
D
,
T
>
{
class
S
plit
Op
:
public
Operator
<
D
,
T
>
{
public:
public:
S
lice
Op
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
S
plit
Op
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
Operator
<
D
,
T
>
(
op_def
,
ws
),
:
Operator
<
D
,
T
>
(
op_def
,
ws
),
functor_
(
OperatorBase
::
GetOptionalArg
<
int
>
(
"axis"
,
3
))
{}
functor_
(
OperatorBase
::
GetOptionalArg
<
int
>
(
"axis"
,
3
))
{}
...
@@ -35,15 +35,15 @@ class SliceOp : public Operator<D, T> {
...
@@ -35,15 +35,15 @@ class SliceOp : public Operator<D, T> {
<<
"There must be at least two outputs for slicing"
;
<<
"There must be at least two outputs for slicing"
;
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
const
Tensor
*
input
=
this
->
Input
(
INPUT
);
const
std
::
vector
<
Tensor
*>
output_list
=
this
->
Outputs
();
const
std
::
vector
<
Tensor
*>
output_list
=
this
->
Outputs
();
const
int32_t
s
lice
_axis
=
OperatorBase
::
GetOptionalArg
<
int
>
(
"axis"
,
3
);
const
int32_t
s
plit
_axis
=
OperatorBase
::
GetOptionalArg
<
int
>
(
"axis"
,
3
);
MACE_CHECK
((
input
->
dim
(
s
lice
_axis
)
%
this
->
OutputSize
())
==
0
)
MACE_CHECK
((
input
->
dim
(
s
plit
_axis
)
%
this
->
OutputSize
())
==
0
)
<<
"Outputs do not split input equally."
;
<<
"Outputs do not split input equally."
;
return
functor_
(
input
,
output_list
,
future
);
return
functor_
(
input
,
output_list
,
future
);
}
}
private:
private:
kernels
::
S
lice
Functor
<
D
,
T
>
functor_
;
kernels
::
S
plit
Functor
<
D
,
T
>
functor_
;
private:
private:
MACE_OP_INPUT_TAGS
(
INPUT
);
MACE_OP_INPUT_TAGS
(
INPUT
);
...
@@ -52,4 +52,4 @@ class SliceOp : public Operator<D, T> {
...
@@ -52,4 +52,4 @@ class SliceOp : public Operator<D, T> {
}
// namespace ops
}
// namespace ops
}
// namespace mace
}
// namespace mace
#endif // MACE_OPS_S
LICE
_H_
#endif // MACE_OPS_S
PLIT
_H_
mace/ops/s
lice
_benchmark.cc
→
mace/ops/s
plit
_benchmark.cc
浏览文件 @
92f18fc6
...
@@ -22,7 +22,7 @@ namespace test {
...
@@ -22,7 +22,7 @@ namespace test {
namespace
{
namespace
{
template
<
DeviceType
D
,
typename
T
>
template
<
DeviceType
D
,
typename
T
>
void
BMS
lice
Helper
(
int
iters
,
void
BMS
plit
Helper
(
int
iters
,
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
index_t
>
&
input_shape
,
const
index_t
num_outputs
)
{
const
index_t
num_outputs
)
{
mace
::
testing
::
StopTiming
();
mace
::
testing
::
StopTiming
();
...
@@ -42,7 +42,7 @@ void BMSliceHelper(int iters,
...
@@ -42,7 +42,7 @@ void BMSliceHelper(int iters,
BufferToImage
<
D
,
T
>
(
&
net
,
"Input"
,
"InputImage"
,
BufferToImage
<
D
,
T
>
(
&
net
,
"Input"
,
"InputImage"
,
kernels
::
BufferType
::
IN_OUT_CHANNEL
);
kernels
::
BufferType
::
IN_OUT_CHANNEL
);
auto
builder
=
OpDefBuilder
(
"S
lice"
,
"Slice
Test"
);
auto
builder
=
OpDefBuilder
(
"S
plit"
,
"Split
Test"
);
builder
.
Input
(
"InputImage"
);
builder
.
Input
(
"InputImage"
);
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
builder
=
builder
.
Output
(
MakeString
(
"OutputImage"
,
i
));
builder
=
builder
.
Output
(
MakeString
(
"OutputImage"
,
i
));
...
@@ -51,7 +51,7 @@ void BMSliceHelper(int iters,
...
@@ -51,7 +51,7 @@ void BMSliceHelper(int iters,
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
.
Finalize
(
net
.
NewOperatorDef
());
.
Finalize
(
net
.
NewOperatorDef
());
}
else
{
}
else
{
auto
builder
=
OpDefBuilder
(
"S
lice"
,
"Slice
Test"
);
auto
builder
=
OpDefBuilder
(
"S
plit"
,
"Split
Test"
);
builder
.
Input
(
"Input"
);
builder
.
Input
(
"Input"
);
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
builder
=
builder
.
Output
(
MakeString
(
"Output"
,
i
));
builder
=
builder
.
Output
(
MakeString
(
"Output"
,
i
));
...
@@ -73,28 +73,28 @@ void BMSliceHelper(int iters,
...
@@ -73,28 +73,28 @@ void BMSliceHelper(int iters,
}
}
}
// namespace
}
// namespace
#define MACE_BM_S
LICE
_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
#define MACE_BM_S
PLIT
_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
static void \
static void \
MACE_BM_S
LICE
_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
MACE_BM_S
PLIT
_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
int iters) { \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
mace::testing::MaccProcessed(tot); \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BMS
lice
Helper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
BMS
plit
Helper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
} \
} \
MACE_BENCHMARK( \
MACE_BENCHMARK( \
MACE_BM_S
LICE
_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
MACE_BM_S
PLIT
_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
#define MACE_BM_S
LICE
(N, H, W, C, NO) \
#define MACE_BM_S
PLIT
(N, H, W, C, NO) \
MACE_BM_S
LICE
_MACRO(N, H, W, C, NO, float, CPU); \
MACE_BM_S
PLIT
_MACRO(N, H, W, C, NO, float, CPU); \
MACE_BM_S
LICE
_MACRO(N, H, W, C, NO, float, GPU); \
MACE_BM_S
PLIT
_MACRO(N, H, W, C, NO, float, GPU); \
MACE_BM_S
LICE
_MACRO(N, H, W, C, NO, half, GPU);
MACE_BM_S
PLIT
_MACRO(N, H, W, C, NO, half, GPU);
MACE_BM_S
LICE
(
1
,
32
,
32
,
32
,
2
);
MACE_BM_S
PLIT
(
1
,
32
,
32
,
32
,
2
);
MACE_BM_S
LICE
(
1
,
32
,
32
,
128
,
2
);
MACE_BM_S
PLIT
(
1
,
32
,
32
,
128
,
2
);
MACE_BM_S
LICE
(
1
,
32
,
32
,
256
,
2
);
MACE_BM_S
PLIT
(
1
,
32
,
32
,
256
,
2
);
MACE_BM_S
LICE
(
1
,
128
,
128
,
32
,
2
);
MACE_BM_S
PLIT
(
1
,
128
,
128
,
32
,
2
);
MACE_BM_S
LICE
(
1
,
128
,
128
,
128
,
2
);
MACE_BM_S
PLIT
(
1
,
128
,
128
,
128
,
2
);
}
// namespace test
}
// namespace test
}
// namespace ops
}
// namespace ops
...
...
mace/ops/s
lice
_test.cc
→
mace/ops/s
plit
_test.cc
浏览文件 @
92f18fc6
...
@@ -17,13 +17,13 @@
...
@@ -17,13 +17,13 @@
#include "gmock/gmock.h"
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/s
lice
.h"
#include "mace/ops/s
plit
.h"
namespace
mace
{
namespace
mace
{
namespace
ops
{
namespace
ops
{
namespace
test
{
namespace
test
{
class
S
lice
OpTest
:
public
OpsTestBase
{};
class
S
plit
OpTest
:
public
OpsTestBase
{};
namespace
{
namespace
{
template
<
DeviceType
D
,
typename
T
>
template
<
DeviceType
D
,
typename
T
>
...
@@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) {
...
@@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) {
BufferToImage
<
D
,
T
>
(
&
net
,
"Input"
,
"InputImage"
,
BufferToImage
<
D
,
T
>
(
&
net
,
"Input"
,
"InputImage"
,
kernels
::
BufferType
::
IN_OUT_CHANNEL
);
kernels
::
BufferType
::
IN_OUT_CHANNEL
);
auto
builder
=
OpDefBuilder
(
"S
lice"
,
"Slice
Test"
);
auto
builder
=
OpDefBuilder
(
"S
plit"
,
"Split
Test"
);
builder
.
Input
(
"InputImage"
);
builder
.
Input
(
"InputImage"
);
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
builder
=
builder
.
Output
(
MakeString
(
"OutputImage"
,
i
));
builder
=
builder
.
Output
(
MakeString
(
"OutputImage"
,
i
));
...
@@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) {
...
@@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) {
builder
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
builder
.
AddIntArg
(
"T"
,
static_cast
<
int
>
(
DataTypeToEnum
<
T
>::
value
))
.
Finalize
(
net
.
NewOperatorDef
());
.
Finalize
(
net
.
NewOperatorDef
());
}
else
{
}
else
{
auto
builder
=
OpDefBuilder
(
"S
lice"
,
"Slice
Test"
).
AddIntArg
(
"axis"
,
axis
);
auto
builder
=
OpDefBuilder
(
"S
plit"
,
"Split
Test"
).
AddIntArg
(
"axis"
,
axis
);
builder
.
Input
(
"Input"
);
builder
.
Input
(
"Input"
);
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
builder
=
builder
.
Output
(
MakeString
(
"Output"
,
i
));
builder
=
builder
.
Output
(
MakeString
(
"Output"
,
i
));
...
@@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) {
...
@@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) {
}
}
}
// namespace
}
// namespace
TEST_F
(
S
lice
OpTest
,
CPU
)
{
TEST_F
(
S
plit
OpTest
,
CPU
)
{
RandomTest
<
DeviceType
::
CPU
,
float
>
(
2
,
3
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
2
,
3
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
4
,
3
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
4
,
3
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
11
,
3
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
11
,
3
);
}
}
TEST_F
(
S
lice
OpTest
,
CPUAxis1
)
{
TEST_F
(
S
plit
OpTest
,
CPUAxis1
)
{
RandomTest
<
DeviceType
::
CPU
,
float
>
(
2
,
1
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
2
,
1
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
4
,
1
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
4
,
1
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
11
,
1
);
RandomTest
<
DeviceType
::
CPU
,
float
>
(
11
,
1
);
}
}
TEST_F
(
S
lice
OpTest
,
OPENCLFloat
)
{
TEST_F
(
S
plit
OpTest
,
OPENCLFloat
)
{
RandomTest
<
DeviceType
::
GPU
,
float
>
(
2
,
3
);
RandomTest
<
DeviceType
::
GPU
,
float
>
(
2
,
3
);
RandomTest
<
DeviceType
::
GPU
,
float
>
(
4
,
3
);
RandomTest
<
DeviceType
::
GPU
,
float
>
(
4
,
3
);
RandomTest
<
DeviceType
::
GPU
,
float
>
(
11
,
3
);
RandomTest
<
DeviceType
::
GPU
,
float
>
(
11
,
3
);
}
}
TEST_F
(
S
lice
OpTest
,
OPENCLHalf
)
{
TEST_F
(
S
plit
OpTest
,
OPENCLHalf
)
{
RandomTest
<
DeviceType
::
GPU
,
half
>
(
2
,
3
);
RandomTest
<
DeviceType
::
GPU
,
half
>
(
2
,
3
);
RandomTest
<
DeviceType
::
GPU
,
half
>
(
4
,
3
);
RandomTest
<
DeviceType
::
GPU
,
half
>
(
4
,
3
);
RandomTest
<
DeviceType
::
GPU
,
half
>
(
11
,
3
);
RandomTest
<
DeviceType
::
GPU
,
half
>
(
11
,
3
);
...
...
mace/ops/strided_slice_test.cc
浏览文件 @
92f18fc6
...
@@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
...
@@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
0
,
3
,
{},
{
6
});
0
,
3
,
{},
{
6
});
}
}
TEST_F
(
StridedSliceOpTest
,
TestStridedSliceRank3
)
{
TestStridedSlice
({
2
,
3
,
2
},
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
},
{
0
,
0
,
0
},
{
2
,
3
,
2
},
{
1
,
2
,
1
},
0
,
0
,
0
,
0
,
0
,
{
2
,
2
,
2
},
{
1
,
2
,
5
,
6
,
7
,
8
,
11
,
12
});
TestStridedSlice
({
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
4
,
4
,
4
,
5
,
5
,
5
,
6
,
6
,
6
},
{
1
,
0
,
0
},
{
2
,
1
,
3
},
{
1
,
1
,
1
},
0
,
0
,
0
,
0
,
0
,
{
1
,
1
,
3
},
{
3
,
3
,
3
});
TestStridedSlice
({
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
4
,
4
,
4
,
5
,
5
,
5
,
6
,
6
,
6
},
{
0
,
0
,
0
},
{
2
,
2
,
2
},
{
1
,
2
,
1
},
0
,
0
,
0
,
0
,
0
,
{
2
,
1
,
2
},
{
1
,
1
,
3
,
3
});
}
TEST_F
(
StridedSliceOpTest
,
TestSlice
)
{
TEST_F
(
StridedSliceOpTest
,
TestSlice
)
{
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
2
,
3
},
TestSlice
({
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
},
{
0
,
0
},
{
2
,
3
},
{
2
,
3
},
{
1
,
2
,
3
,
4
,
5
,
6
});
{
1
,
2
,
3
,
4
,
5
,
6
});
...
...
mace/python/tools/converter_tool/base_converter.py
浏览文件 @
92f18fc6
...
@@ -88,6 +88,7 @@ MaceSupportedOps = [
...
@@ -88,6 +88,7 @@ MaceSupportedOps = [
'Dequantize'
,
'Dequantize'
,
'Eltwise'
,
'Eltwise'
,
'FoldedBatchNorm'
,
'FoldedBatchNorm'
,
'Fill'
,
'FullyConnected'
,
'FullyConnected'
,
'Gather'
,
'Gather'
,
'Identity'
,
'Identity'
,
...
@@ -101,6 +102,7 @@ MaceSupportedOps = [
...
@@ -101,6 +102,7 @@ MaceSupportedOps = [
'Reshape'
,
'Reshape'
,
'ResizeBilinear'
,
'ResizeBilinear'
,
'Slice'
,
'Slice'
,
'Split'
,
'Shape'
,
'Shape'
,
'Squeeze'
,
'Squeeze'
,
'Stack'
,
'Stack'
,
...
@@ -146,6 +148,7 @@ class MaceKeyword(object):
...
@@ -146,6 +148,7 @@ class MaceKeyword(object):
mace_constant_value_str
=
'constant_value'
mace_constant_value_str
=
'constant_value'
mace_dims_str
=
'dims'
mace_dims_str
=
'dims'
mace_axis_str
=
'axis'
mace_axis_str
=
'axis'
mace_num_split_str
=
'num_split'
mace_keepdims_str
=
'keepdims'
mace_keepdims_str
=
'keepdims'
mace_shape_str
=
'shape'
mace_shape_str
=
'shape'
mace_winograd_filter_transformed
=
'is_filter_transformed'
mace_winograd_filter_transformed
=
'is_filter_transformed'
...
...
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
92f18fc6
...
@@ -68,6 +68,7 @@ TFSupportedOps = [
...
@@ -68,6 +68,7 @@ TFSupportedOps = [
'Relu6'
,
'Relu6'
,
'Tanh'
,
'Tanh'
,
'Sigmoid'
,
'Sigmoid'
,
'Fill'
,
'FusedBatchNorm'
,
'FusedBatchNorm'
,
'AvgPool'
,
'AvgPool'
,
'MaxPool'
,
'MaxPool'
,
...
@@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
...
@@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType
.
Relu6
.
name
:
self
.
convert_activation
,
TFOpType
.
Relu6
.
name
:
self
.
convert_activation
,
TFOpType
.
Tanh
.
name
:
self
.
convert_activation
,
TFOpType
.
Tanh
.
name
:
self
.
convert_activation
,
TFOpType
.
Sigmoid
.
name
:
self
.
convert_activation
,
TFOpType
.
Sigmoid
.
name
:
self
.
convert_activation
,
TFOpType
.
Fill
.
name
:
self
.
convert_fill
,
TFOpType
.
FusedBatchNorm
.
name
:
self
.
convert_fused_batchnorm
,
TFOpType
.
FusedBatchNorm
.
name
:
self
.
convert_fused_batchnorm
,
TFOpType
.
AvgPool
.
name
:
self
.
convert_pooling
,
TFOpType
.
AvgPool
.
name
:
self
.
convert_pooling
,
TFOpType
.
MaxPool
.
name
:
self
.
convert_pooling
,
TFOpType
.
MaxPool
.
name
:
self
.
convert_pooling
,
...
@@ -458,6 +460,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
...
@@ -458,6 +460,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
limit_arg
.
name
=
MaceKeyword
.
mace_activation_max_limit_str
limit_arg
.
name
=
MaceKeyword
.
mace_activation_max_limit_str
limit_arg
.
f
=
6.0
limit_arg
.
f
=
6.0
def
convert_fill
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
Fill
.
name
value_arg
=
op
.
arg
.
add
()
value_arg
.
name
=
MaceKeyword
.
mace_value_str
value_arg
.
f
=
tf_op
.
inputs
[
1
].
eval
()
def
convert_fused_batchnorm
(
self
,
tf_op
):
def
convert_fused_batchnorm
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
FoldedBatchNorm
.
name
op
.
type
=
MaceOp
.
FoldedBatchNorm
.
name
...
@@ -763,19 +773,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
...
@@ -763,19 +773,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
op
.
output_type
.
extend
([
mace_pb2
.
DT_INT32
])
op
.
output_type
.
extend
([
mace_pb2
.
DT_INT32
])
def
convert_split
(
self
,
tf_op
):
def
convert_split
(
self
,
tf_op
):
# inputs: [dim, input]
axis
=
tf_op
.
inputs
[
0
].
eval
().
astype
(
np
.
int32
)
axis
=
tf_op
.
inputs
[
0
].
eval
().
astype
(
np
.
int32
)
axis
=
len
(
op
.
output_shape
[
0
].
dims
)
+
axis
if
axis
<
0
else
axis
axis
=
len
(
op
.
output_shape
[
0
].
dims
)
+
axis
if
axis
<
0
else
axis
mace_check
(
axis
==
3
,
'Split with %d axis only support'
%
axis
)
input_shape
=
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
1
])
input_shape
=
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
1
])
mace_check
(
len
(
input_shape
)
==
4
and
(
input_shape
[
3
]
%
4
==
0
),
"The input's 4th dimension should be a multiple of 4"
)
op
=
self
.
convert_general_op
(
tf_op
)
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
S
lice
.
name
op
.
type
=
MaceOp
.
S
plit
.
name
del
op
.
input
[
0
]
del
op
.
input
[
0
]
axis_arg
=
op
.
arg
.
add
()
axis_arg
=
op
.
arg
.
add
()
axis_arg
.
name
=
MaceKeyword
.
mace_axis_str
axis_arg
.
name
=
MaceKeyword
.
mace_axis_str
axis_arg
.
i
=
axis
axis_arg
.
i
=
axis
num_split_arg
=
op
.
arg
.
add
()
num_split_arg
.
name
=
MaceKeyword
.
mace_num_split_str
num_split_arg
.
i
=
tf_op
.
get_attr
(
'num_split'
)
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
0
].
name
)
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
0
].
name
)
mace/python/tools/converter_tool/transformer.py
浏览文件 @
92f18fc6
...
@@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface):
"only support concat at "
"only support concat at "
"channel dimension"
)
"channel dimension"
)
arg
.
i
=
3
arg
.
i
=
3
producer
=
self
.
_producer
[
op
.
input
[
0
]]
producer
=
self
.
_producer
[
op
.
input
[
0
]]
input_shape
=
producer
.
output_shape
[
0
].
dims
input_shape
=
producer
.
output_shape
[
0
].
dims
if
producer
.
type
==
MaceOp
.
FullyConnected
.
name
and
\
if
producer
.
type
==
MaceOp
.
FullyConnected
.
name
and
\
...
...
repository/opencl-kernel/opencl_kernel_configure.bzl
浏览文件 @
92f18fc6
...
@@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
...
@@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/pooling.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/pooling.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/reduce_mean.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/reduce_mean.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/resize_bilinear.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/resize_bilinear.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/s
lice
.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/s
plit
.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/softmax.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/softmax.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/space_to_batch.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/space_to_batch.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/winograd_transform.cl"
))
unused_var
=
repository_ctx
.
path
(
Label
(
"//:mace/kernels/opencl/cl/winograd_transform.cl"
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录