Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
毕竟曾有刹那
Mace
提交
71266960
Mace
项目概览
毕竟曾有刹那
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
71266960
编写于
11月 09, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fold embedding lookup
上级
fc7f4967
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
36 addition
and
22 deletion
+36
-22
mace/ops/gather.cc
mace/ops/gather.cc
+9
-16
mace/ops/gather_test.cc
mace/ops/gather_test.cc
+4
-6
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+1
-0
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+22
-0
未找到文件。
mace/ops/gather.cc
浏览文件 @
71266960
...
...
@@ -20,15 +20,11 @@ namespace mace {
namespace
ops
{
template
<
DeviceType
D
,
class
T
>
class
GatherOp
;
template
<
>
class
GatherOp
<
DeviceType
::
CPU
,
float
>
:
public
Operation
{
class
GatherOp
:
public
Operation
{
public:
explicit
GatherOp
(
OpConstructContext
*
context
)
:
Operation
(
context
),
axis_
(
Operation
::
GetOptionalArg
<
int
>
(
"axis"
,
0
)),
y_
(
Operation
::
GetOptionalArg
<
float
>
(
"y"
,
1.0
))
{}
axis_
(
Operation
::
GetOptionalArg
<
int
>
(
"axis"
,
0
))
{}
MaceStatus
Run
(
OpContext
*
context
)
override
{
MACE_UNUSED
(
context
);
...
...
@@ -54,8 +50,8 @@ class GatherOp<DeviceType::CPU, float> : public Operation {
Tensor
::
MappingGuard
params_guard
(
params
);
Tensor
::
MappingGuard
output_guard
(
output
);
const
int32_t
*
indices_data
=
indices
->
data
<
int32_t
>
();
const
float
*
params_data
=
params
->
data
<
float
>
();
float
*
output_data
=
output
->
mutable_data
<
float
>
();
const
T
*
params_data
=
params
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
();
index_t
axis_dim_size
=
params
->
dim
(
axis_
);
index_t
lhs_size
=
std
::
accumulate
(
params
->
shape
().
begin
(),
...
...
@@ -74,23 +70,18 @@ class GatherOp<DeviceType::CPU, float> : public Operation {
memcpy
(
output_data
+
((
l
*
index_size
)
+
idx
)
*
rhs_size
,
params_data
+
((
l
*
axis_dim_size
)
+
indices_data
[
idx
])
*
rhs_size
,
sizeof
(
float
)
*
rhs_size
);
sizeof
(
T
)
*
rhs_size
);
}
}
if
(
std
::
fabs
(
y_
-
1.0
)
>
1e-6
)
{
#pragma omp parallel for
for
(
index_t
i
=
0
;
i
<
output
->
size
();
++
i
)
{
output_data
[
i
]
*=
y_
;
}
}
output
->
SetScale
(
params
->
scale
());
output
->
SetZeroPoint
(
params
->
zero_point
());
return
MaceStatus
::
MACE_SUCCESS
;
}
private:
int
axis_
;
float
y_
;
MACE_OP_INPUT_TAGS
(
PARAMS
,
INDICES
);
MACE_OP_OUTPUT_TAGS
(
OUTPUT
);
};
...
...
@@ -98,6 +89,8 @@ class GatherOp<DeviceType::CPU, float> : public Operation {
void
RegisterGather
(
OpRegistryBase
*
op_registry
)
{
MACE_REGISTER_OP
(
op_registry
,
"Gather"
,
GatherOp
,
DeviceType
::
CPU
,
float
);
MACE_REGISTER_OP
(
op_registry
,
"Gather"
,
GatherOp
,
DeviceType
::
CPU
,
uint8_t
);
}
}
// namespace ops
...
...
mace/ops/gather_test.cc
浏览文件 @
71266960
...
...
@@ -28,7 +28,6 @@ void TestGather(const std::vector<index_t> &weight_shape,
const
std
::
vector
<
index_t
>
&
input_shape
,
const
std
::
vector
<
int32_t
>
&
input
,
const
int
axis
,
const
float
y
,
const
std
::
vector
<
index_t
>
&
output_shape
,
const
std
::
vector
<
float
>
&
output
)
{
OpsTestNet
net
;
...
...
@@ -40,7 +39,6 @@ void TestGather(const std::vector<index_t> &weight_shape,
.
Input
(
"Params"
)
.
Input
(
"Indices"
)
.
AddIntArg
(
"axis"
,
axis
)
.
AddFloatArg
(
"y"
,
y
)
.
Output
(
"Output"
)
.
Finalize
(
net
.
NewOperatorDef
());
// Run
...
...
@@ -55,25 +53,25 @@ void TestGather(const std::vector<index_t> &weight_shape,
TEST_F
(
GatherOpTest
,
CPUScalarIndex
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{},
{
5
},
0
,
2.0
,
{
2
},
{
20
,
22
});
{},
{
5
},
0
,
{
2
},
{
10
,
11
});
}
TEST_F
(
GatherOpTest
,
CPURank1Index
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
3
},
{
2
,
4
,
6
},
0
,
1.0
,
{
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
{
3
},
{
2
,
4
,
6
},
0
,
{
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
}
TEST_F
(
GatherOpTest
,
CPURank1IndexWithAxis1
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
1
},
{
1
},
1
,
1.0
,
{
10
,
1
},
{
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
,
17
,
19
});
{
1
},
{
1
},
1
,
{
10
,
1
},
{
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
,
17
,
19
});
}
TEST_F
(
GatherOpTest
,
CPURankHighIndex
)
{
TestGather
({
10
,
2
},
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
},
{
1
,
3
},
{
2
,
4
,
6
},
0
,
1.0
,
{
1
,
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
{
1
,
3
},
{
2
,
4
,
6
},
0
,
{
1
,
3
,
2
},
{
4
,
5
,
8
,
9
,
12
,
13
});
}
}
// namespace test
...
...
mace/python/tools/converter_tool/base_converter.py
浏览文件 @
71266960
...
...
@@ -222,6 +222,7 @@ class TransformerRule(Enum):
FOLD_DECONV_AND_BN
=
32
FOLD_SQRDIFF_MEAN
=
33
TRANSPOSE_MATMUL_WEIGHT
=
34
FOLD_EMBEDDING_LOOKUP
=
35
class
ConverterInterface
(
object
):
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
71266960
...
...
@@ -79,6 +79,7 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule
.
FLATTEN_ATROUS_CONV
:
self
.
flatten_atrous_conv
,
TransformerRule
.
FOLD_ACTIVATION
:
self
.
fold_activation
,
TransformerRule
.
FOLD_SQRDIFF_MEAN
:
self
.
fold_squared_diff_mean
,
TransformerRule
.
FOLD_EMBEDDING_LOOKUP
:
self
.
fold_embedding_lookup
,
TransformerRule
.
TRANSPOSE_FILTERS
:
self
.
transpose_filters
,
TransformerRule
.
TRANSPOSE_MATMUL_WEIGHT
:
self
.
transpose_matmul_weight
,
...
...
@@ -392,6 +393,27 @@ class Transformer(base_converter.ConverterInterface):
return
False
def
fold_embedding_lookup
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
# gather -> mul
if
(
op
.
type
==
MaceOp
.
Gather
.
name
and
self
.
consumer_count
(
op
.
output
[
0
])
==
1
):
consumer_op
=
self
.
_consumers
[
op
.
output
[
0
]][
0
]
if
(
consumer_op
.
type
==
MaceOp
.
Eltwise
.
name
and
ConverterUtil
.
get_arg
(
consumer_op
,
MaceKeyword
.
mace_element_type_str
).
i
==
EltwiseType
.
PROD
.
value
and
# noqa
len
(
consumer_op
.
input
)
==
1
and
op
.
input
[
0
]
in
self
.
_consts
and
self
.
consumer_count
(
op
.
input
[
0
])
==
1
):
print
(
"Fold Gather and Mul: %s"
%
op
.
name
)
gather_weights
=
self
.
_consts
[
op
.
input
[
0
]]
mul_weight
=
ConverterUtil
.
get_arg
(
consumer_op
,
MaceKeyword
.
mace_scalar_input_str
).
f
# noqa
gather_weights
.
float_data
[:]
=
gather_weights
.
float_data
*
mul_weight
# noqa
self
.
safe_remove_node
(
consumer_op
,
None
,
remove_input_tensor
=
True
)
def
transform_lstmcell_zerostate
(
self
):
net
=
self
.
_model
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录