Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
c384a6e2
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c384a6e2
编写于
4月 20, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'gemv' into 'master'
Fix gemv multi-batch case. See merge request !398.
上级
6c8cc84e
3e867dca
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
103 additions
and
84 deletions
+103
-84
mace/kernels/arm/fully_connected.cc
mace/kernels/arm/fully_connected.cc
+2
-2
mace/kernels/gemm.cc
mace/kernels/gemm.cc
+94
-79
mace/kernels/gemm.h
mace/kernels/gemm.h
+2
-0
mace/kernels/gemm_test.cc
mace/kernels/gemm_test.cc
+2
-2
mace/ops/conv_2d_test.cc
mace/ops/conv_2d_test.cc
+1
-1
mace/ops/fully_connected_test.cc
mace/ops/fully_connected_test.cc
+2
-0
未找到文件。
mace/kernels/arm/fully_connected.cc
浏览文件 @
c384a6e2
...
...
@@ -34,10 +34,10 @@ void FullyConnectedFunctor<DeviceType::NEON,
const
float
*
bias_ptr
=
bias
==
nullptr
?
nullptr
:
bias
->
data
<
float
>
();
float
*
output_ptr
=
output
->
mutable_data
<
float
>
();
Gemv
(
weight_ptr
,
input_ptr
,
N
,
input_size
,
output_size
,
output_ptr
);
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
Gemv
(
weight_ptr
,
input_ptr
,
input_size
,
output_size
,
output_ptr
);
for
(
int
j
=
0
;
j
<
output_size
;
++
j
)
{
output_ptr
[
j
]
+=
bias_ptr
[
j
];
output_ptr
[
j
+
i
*
output_size
]
+=
bias_ptr
[
j
];
}
}
...
...
mace/kernels/gemm.cc
浏览文件 @
c384a6e2
...
...
@@ -566,6 +566,7 @@ inline void GemmTile(const float *A,
}
}
// namespace
// A: height x K, B: K x width, C: height x width
void
Gemm
(
const
float
*
A
,
const
float
*
B
,
const
index_t
batch
,
...
...
@@ -573,6 +574,12 @@ void Gemm(const float *A,
const
index_t
K
,
const
index_t
width
,
float
*
C
)
{
if
(
width
==
1
)
{
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
Gemv
(
A
+
b
*
height
*
K
,
B
+
b
*
K
,
1
,
K
,
height
,
C
+
b
*
height
);
}
return
;
}
memset
(
C
,
0
,
sizeof
(
float
)
*
batch
*
height
*
width
);
...
...
@@ -628,6 +635,7 @@ void Gemm(const float *A,
}
// n
}
// A: height x K, B: K x width, C: height x width
void
GemmRef
(
const
float
*
A
,
const
float
*
B
,
const
index_t
height
,
...
...
@@ -647,19 +655,24 @@ void GemmRef(const float *A,
void
GemvRef
(
const
float
*
m_ptr
,
const
float
*
v_ptr
,
const
index_t
batch
,
const
index_t
width
,
const
index_t
height
,
float
*
out_ptr
)
{
memset
(
out_ptr
,
0
,
sizeof
(
float
)
*
height
);
for
(
int
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
width
;
++
w
)
{
out_ptr
[
h
]
+=
v_ptr
[
w
]
*
m_ptr
[
h
*
width
+
w
];
memset
(
out_ptr
,
0
,
sizeof
(
float
)
*
height
*
batch
);
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
for
(
int
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
width
;
++
w
)
{
out_ptr
[
h
+
b
*
height
]
+=
v_ptr
[
w
+
b
*
width
]
*
m_ptr
[
h
*
width
+
w
];
}
}
}
}
// M: height x width, Vin: width x 1, Vout: height x 1
void
Gemv
(
const
float
*
m_ptr
,
const
float
*
v_ptr
,
const
index_t
batch
,
const
index_t
width
,
const
index_t
height
,
float
*
out_ptr
)
{
...
...
@@ -669,88 +682,90 @@ void Gemv(const float *m_ptr,
index_t
remain_w
=
width
-
(
width_d4
<<
2
);
index_t
remain_h
=
height
-
(
height_d4
<<
2
);
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
#pragma omp parallel for
for
(
index_t
h
=
0
;
h
<
height_d4
;
++
h
)
{
const
float
*
m_ptr0
=
m_ptr
+
h
*
width
*
4
;
const
float
*
m_ptr1
=
m_ptr0
+
width
;
const
float
*
m_ptr2
=
m_ptr1
+
width
;
const
float
*
m_ptr3
=
m_ptr2
+
width
;
const
float
*
v_ptr0
=
v_ptr
;
float
*
out_ptr0
=
out_ptr
+
h
*
4
;
float32x4_t
vm0
,
vm1
,
vm2
,
vm3
;
float32x4_t
vv
;
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum1
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum2
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum3
=
vdupq_n_f32
(
0.
f
);
for
(
index_t
w
=
0
;
w
<
width_d4
;
++
w
)
{
vm0
=
vld1q_f32
(
m_ptr0
);
vm1
=
vld1q_f32
(
m_ptr1
);
vm2
=
vld1q_f32
(
m_ptr2
);
vm3
=
vld1q_f32
(
m_ptr3
);
vv
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm0
,
vv
);
vsum1
=
vmlaq_f32
(
vsum1
,
vm1
,
vv
);
vsum2
=
vmlaq_f32
(
vsum2
,
vm2
,
vv
);
vsum3
=
vmlaq_f32
(
vsum3
,
vm3
,
vv
);
m_ptr0
+=
4
;
m_ptr1
+=
4
;
m_ptr2
+=
4
;
m_ptr3
+=
4
;
v_ptr0
+=
4
;
}
float
sum0
=
vaddvq_f32
(
vsum0
);
float
sum1
=
vaddvq_f32
(
vsum1
);
float
sum2
=
vaddvq_f32
(
vsum2
);
float
sum3
=
vaddvq_f32
(
vsum3
);
// handle remaining w
for
(
index_t
w
=
0
;
w
<
remain_w
;
++
w
)
{
sum0
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
sum1
+=
m_ptr1
[
0
]
*
v_ptr0
[
0
];
sum2
+=
m_ptr2
[
0
]
*
v_ptr0
[
0
];
sum3
+=
m_ptr3
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
m_ptr1
++
;
m_ptr2
++
;
m_ptr3
++
;
v_ptr0
++
;
for
(
index_t
h
=
0
;
h
<
height_d4
;
++
h
)
{
const
float
*
m_ptr0
=
m_ptr
+
h
*
width
*
4
;
const
float
*
m_ptr1
=
m_ptr0
+
width
;
const
float
*
m_ptr2
=
m_ptr1
+
width
;
const
float
*
m_ptr3
=
m_ptr2
+
width
;
const
float
*
v_ptr0
=
v_ptr
+
b
*
width
;
float
*
out_ptr0
=
out_ptr
+
h
*
4
+
b
*
height
;
float32x4_t
vm0
,
vm1
,
vm2
,
vm3
;
float32x4_t
vv
;
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum1
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum2
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum3
=
vdupq_n_f32
(
0.
f
);
for
(
index_t
w
=
0
;
w
<
width_d4
;
++
w
)
{
vm0
=
vld1q_f32
(
m_ptr0
);
vm1
=
vld1q_f32
(
m_ptr1
);
vm2
=
vld1q_f32
(
m_ptr2
);
vm3
=
vld1q_f32
(
m_ptr3
);
vv
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm0
,
vv
);
vsum1
=
vmlaq_f32
(
vsum1
,
vm1
,
vv
);
vsum2
=
vmlaq_f32
(
vsum2
,
vm2
,
vv
);
vsum3
=
vmlaq_f32
(
vsum3
,
vm3
,
vv
);
m_ptr0
+=
4
;
m_ptr1
+=
4
;
m_ptr2
+=
4
;
m_ptr3
+=
4
;
v_ptr0
+=
4
;
}
float
sum0
=
vaddvq_f32
(
vsum0
);
float
sum1
=
vaddvq_f32
(
vsum1
);
float
sum2
=
vaddvq_f32
(
vsum2
);
float
sum3
=
vaddvq_f32
(
vsum3
);
// handle remaining w
for
(
index_t
w
=
0
;
w
<
remain_w
;
++
w
)
{
sum0
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
sum1
+=
m_ptr1
[
0
]
*
v_ptr0
[
0
];
sum2
+=
m_ptr2
[
0
]
*
v_ptr0
[
0
];
sum3
+=
m_ptr3
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
m_ptr1
++
;
m_ptr2
++
;
m_ptr3
++
;
v_ptr0
++
;
}
*
out_ptr0
++
=
sum0
;
*
out_ptr0
++
=
sum1
;
*
out_ptr0
++
=
sum2
;
*
out_ptr0
++
=
sum3
;
}
*
out_ptr0
++
=
sum0
;
*
out_ptr0
++
=
sum1
;
*
out_ptr0
++
=
sum2
;
*
out_ptr0
++
=
sum3
;
}
// handle remaining h
index_t
remain_start_height
=
height_d4
<<
2
;
// handle remaining h
index_t
remain_start_height
=
height_d4
<<
2
;
#pragma omp parallel for
for
(
index_t
h
=
0
;
h
<
remain_h
;
++
h
)
{
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
const
float
*
m_ptr0
=
m_ptr
+
(
h
+
remain_start_height
)
*
width
;
const
float
*
v_ptr0
=
v_ptr
;
for
(
index_t
w
=
0
;
w
<
width_d4
;
++
w
)
{
float32x4_t
vm
=
vld1q_f32
(
m_ptr0
);
float32x4_t
vv
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm
,
vv
);
m_ptr0
+=
4
;
v_ptr0
+=
4
;
}
float
sum
=
vaddvq_f32
(
vsum0
);
for
(
index_t
w
=
0
;
w
<
remain_w
;
++
w
)
{
sum
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
v_ptr0
++
;
for
(
index_t
h
=
0
;
h
<
remain_h
;
++
h
)
{
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
const
float
*
m_ptr0
=
m_ptr
+
(
h
+
remain_start_height
)
*
width
;
const
float
*
v_ptr0
=
v_ptr
;
for
(
index_t
w
=
0
;
w
<
width_d4
;
++
w
)
{
float32x4_t
vm
=
vld1q_f32
(
m_ptr0
);
float32x4_t
vv
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm
,
vv
);
m_ptr0
+=
4
;
v_ptr0
+=
4
;
}
float
sum
=
vaddvq_f32
(
vsum0
);
for
(
index_t
w
=
0
;
w
<
remain_w
;
++
w
)
{
sum
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
v_ptr0
++
;
}
out_ptr
[
remain_start_height
+
h
]
=
sum
;
}
out_ptr
[
remain_start_height
+
h
]
=
sum
;
}
#else
GemvRef
(
m_ptr
,
v_ptr
,
width
,
height
,
out_ptr
);
GemvRef
(
m_ptr
,
v_ptr
,
batch
,
width
,
height
,
out_ptr
);
#endif
}
...
...
mace/kernels/gemm.h
浏览文件 @
c384a6e2
...
...
@@ -41,12 +41,14 @@ void GemmRef(const float *A,
void
Gemv
(
const
float
*
m_ptr
,
const
float
*
v_ptr
,
const
index_t
batch
,
const
index_t
width
,
const
index_t
height
,
float
*
out_ptr
);
void
GemvRef
(
const
float
*
m_ptr
,
const
float
*
v_ptr
,
const
index_t
batch
,
const
index_t
width
,
const
index_t
height
,
float
*
out_ptr
);
...
...
mace/kernels/gemm_test.cc
浏览文件 @
c384a6e2
...
...
@@ -70,8 +70,8 @@ TEST(GEMMTest, gemv) {
[
&
gen
,
&
nd
]
{
return
nd
(
gen
);
});
kernels
::
Gemv
(
A
.
get
(),
B
.
get
(),
K
,
N
,
C
.
get
());
kernels
::
GemvRef
(
A
.
get
(),
B
.
get
(),
K
,
N
,
C_ref
.
get
());
kernels
::
Gemv
(
A
.
get
(),
B
.
get
(),
1
,
K
,
N
,
C
.
get
());
kernels
::
GemvRef
(
A
.
get
(),
B
.
get
(),
1
,
K
,
N
,
C_ref
.
get
());
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
EXPECT_NEAR
(
C_ref
[
i
],
C
[
i
],
0.1
);
...
...
mace/ops/conv_2d_test.cc
浏览文件 @
c384a6e2
...
...
@@ -826,7 +826,7 @@ static void TestNeonArbitraryPadConvNxN(const std::vector<index_t> &shape,
for
(
int
kernel_size
:
{
1
,
3
,
5
})
{
for
(
int
stride
:
{
1
,
2
})
{
if
(
stride
<
kernel_size
)
{
if
(
stride
<
=
kernel_size
)
{
func
(
kernel_size
,
kernel_size
,
stride
,
stride
);
}
}
...
...
mace/ops/fully_connected_test.cc
浏览文件 @
c384a6e2
...
...
@@ -337,6 +337,8 @@ TEST_F(FullyConnectedOpTest, TestNEON) {
FullyConnectedTestNEON
(
1
,
7
,
7
,
32
,
16
);
FullyConnectedTestNEON
(
1
,
7
,
7
,
512
,
128
);
FullyConnectedTestNEON
(
1
,
1
,
1
,
2048
,
1024
);
FullyConnectedTestNEON
(
3
,
1
,
1
,
16
,
8
);
FullyConnectedTestNEON
(
3
,
7
,
7
,
32
,
16
);
}
}
// namespace test
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录