提交 9bbac6ea 编写于 作者: W wuchenghui

implement fc with gemv and convert global conv to fc

上级 6cfa7ab0
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <random> #include <random>
#include <algorithm> #include <algorithm>
#include <memory>
#include "mace/kernels/arm/conv_winograd.h" #include "mace/kernels/arm/conv_winograd.h"
#include "mace/core/types.h" #include "mace/core/types.h"
...@@ -25,51 +26,46 @@ TEST(ConvWinogradTest, winograd) { ...@@ -25,51 +26,46 @@ TEST(ConvWinogradTest, winograd) {
index_t filter_size = 3 * 3 * in_channels * out_channels; index_t filter_size = 3 * 3 * in_channels * out_channels;
index_t output_size = batch * out_channels * out_height * out_width; index_t output_size = batch * out_channels * out_height * out_width;
float *input_data = new float[input_size]; std::unique_ptr<float[]> input_data(new float[input_size]);
float *filter_data = new float[filter_size]; std::unique_ptr<float[]> filter_data(new float[filter_size]);
float *output_data = new float[output_size]; std::unique_ptr<float[]> output_data(new float[output_size]);
float *output_data_ref = new float[output_size]; std::unique_ptr<float[]> output_data_ref(new float[output_size]);
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1); std::normal_distribution<float> nd(0, 1);
std::generate(input_data, input_data + input_size, std::generate(input_data.get(), input_data.get() + input_size,
[&gen, &nd] { [&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen))); return std::max(-1.0f, std::min(1.0f, nd(gen)));
}); });
std::generate(filter_data, filter_data + filter_size, std::generate(filter_data.get(), filter_data.get() + filter_size,
[&gen, &nd] { [&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen))); return std::max(-1.0f, std::min(1.0f, nd(gen)));
}); });
kernels::ConvRef3x3s1(input_data, kernels::ConvRef3x3s1(input_data.get(),
filter_data, filter_data.get(),
batch, batch,
in_height, in_height,
in_width, in_width,
in_channels, in_channels,
out_channels, out_channels,
output_data_ref); output_data_ref.get());
kernels::WinoGradConv3x3s1(input_data, kernels::WinoGradConv3x3s1(input_data.get(),
filter_data, filter_data.get(),
batch, batch,
in_height, in_height,
in_width, in_width,
in_channels, in_channels,
out_channels, out_channels,
6, 6,
output_data); output_data.get());
// test // test
for (index_t i = 0; i < output_size; ++i) { for (index_t i = 0; i < output_size; ++i) {
EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1) << " with index " << i; EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1) << " with index " << i;
} }
delete[]input_data;
delete[]filter_data;
delete[]output_data;
delete[]output_data_ref;
} }
} // namespace kernels } // namespace kernels
......
...@@ -25,7 +25,7 @@ void FullyConnectedFunctor<DeviceType::NEON, ...@@ -25,7 +25,7 @@ void FullyConnectedFunctor<DeviceType::NEON,
float *output_ptr = output->mutable_data<float>(); float *output_ptr = output->mutable_data<float>();
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
Gemm(weight_ptr, input_ptr, 1, output_size, input_size, 1, output_ptr); Gemv(weight_ptr, input_ptr, input_size, output_size, output_ptr);
for (int j = 0; j < output_size; ++j) { for (int j = 0; j < output_size; ++j) {
output_ptr[j] += bias_ptr[j]; output_ptr[j] += bias_ptr[j];
} }
......
...@@ -514,5 +514,115 @@ void GemmRef(const float *A, ...@@ -514,5 +514,115 @@ void GemmRef(const float *A,
} }
} }
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t width,
const index_t height,
float *out_ptr) {
memset(out_ptr, 0, sizeof(float) * height);
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
out_ptr[h] += v_ptr[w] * m_ptr[h * width + w];
}
}
}
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t width,
const index_t height,
float *out_ptr) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
index_t height_d4 = height >> 2;
index_t width_d4 = width >> 2;
index_t remain_w = width - (width_d4 << 2);
index_t remain_h = height - (height_d4 << 2);
#pragma omp parallel for
for (index_t h = 0; h < height_d4; ++h) {
const float *m_ptr0 = m_ptr + h * width * 4;
const float *m_ptr1 = m_ptr0 + width;
const float *m_ptr2 = m_ptr1 + width;
const float *m_ptr3 = m_ptr2 + width;
const float *v_ptr0 = v_ptr;
float *out_ptr0 = out_ptr + h * 4;
float32x4_t vm0, vm1, vm2, vm3;
float32x4_t vv;
float32x4_t vsum0 = vdupq_n_f32(0.f);
float32x4_t vsum1 = vdupq_n_f32(0.f);
float32x4_t vsum2 = vdupq_n_f32(0.f);
float32x4_t vsum3 = vdupq_n_f32(0.f);
for (index_t w = 0; w < width_d4; ++w) {
vm0 = vld1q_f32(m_ptr0);
vm1 = vld1q_f32(m_ptr1);
vm2 = vld1q_f32(m_ptr2);
vm3 = vld1q_f32(m_ptr3);
vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm0, vv);
vsum1 = vmlaq_f32(vsum1, vm1, vv);
vsum2 = vmlaq_f32(vsum2, vm2, vv);
vsum3 = vmlaq_f32(vsum3, vm3, vv);
m_ptr0 += 4;
m_ptr1 += 4;
m_ptr2 += 4;
m_ptr3 += 4;
v_ptr0 += 4;
}
float sum0 = vaddvq_f32(vsum0);
float sum1 = vaddvq_f32(vsum1);
float sum2 = vaddvq_f32(vsum2);
float sum3 = vaddvq_f32(vsum3);
// handle remaining w
for (index_t w = 0; w < remain_w; ++w) {
sum0 += m_ptr0[0] * v_ptr0[0];
sum1 += m_ptr1[0] * v_ptr0[0];
sum2 += m_ptr2[0] * v_ptr0[0];
sum3 += m_ptr3[0] * v_ptr0[0];
m_ptr0++;
m_ptr1++;
m_ptr2++;
m_ptr3++;
v_ptr0++;
}
*out_ptr0++ = sum0;
*out_ptr0++ = sum1;
*out_ptr0++ = sum2;
*out_ptr0++ = sum3;
}
// handle remaining h
index_t remain_start_height = height_d4 << 2;
#pragma omp parallel for
for (index_t h = 0; h < remain_h; ++h) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
const float *v_ptr0 = v_ptr;
for (index_t w = 0; w < width_d4; ++w) {
float32x4_t vm = vld1q_f32(m_ptr0);
float32x4_t vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm, vv);
m_ptr0 += 4;
v_ptr0 += 4;
}
float sum = vaddvq_f32(vsum0);
for (index_t w = 0; w < remain_w; ++w) {
sum += m_ptr0[0] * v_ptr0[0];
m_ptr0++;
v_ptr0++;
}
out_ptr[remain_start_height + h] = sum;
}
#else
GemvRef(m_ptr, v_ptr, width, height, out_ptr);
#endif
}
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -29,6 +29,18 @@ void GemmRef(const float *A, ...@@ -29,6 +29,18 @@ void GemmRef(const float *A,
const index_t width, const index_t width,
float *C); float *C);
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t width,
const index_t height,
float *out_ptr);
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t width,
const index_t height,
float *out_ptr);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <random> #include <random>
#include <memory>
#include "mace/kernels/gemm.h" #include "mace/kernels/gemm.h"
#include "mace/core/types.h" #include "mace/core/types.h"
...@@ -14,33 +15,58 @@ TEST(GEMMTest, gemm) { ...@@ -14,33 +15,58 @@ TEST(GEMMTest, gemm) {
index_t N = 17; index_t N = 17;
index_t M = 33; index_t M = 33;
index_t K = 64; index_t K = 64;
float *A = new float[N * K]; std::unique_ptr<float[]> A(new float[N * K]);
float *B = new float[K * M]; std::unique_ptr<float[]> B(new float[K * M]);
float *C = new float[N * M]; std::unique_ptr<float[]> C(new float[N * M]);
float *C_ref = new float[N * M]; std::unique_ptr<float[]> C_ref(new float[N * M]);
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1); std::normal_distribution<float> nd(0, 1);
std::generate(A, A + N * K, std::generate(A.get(), A.get() + N * K,
[&gen, &nd] { [&gen, &nd] {
return nd(gen); return nd(gen);
}); });
std::generate(B, B + K * M, std::generate(B.get(), B.get() + K * M,
[&gen, &nd] { [&gen, &nd] {
return nd(gen); return nd(gen);
}); });
kernels::Gemm(A, B, 1, N, K, M, C); kernels::Gemm(A.get(), B.get(), 1, N, K, M, C.get());
kernels::GemmRef(A, B, N, K, M, C_ref); kernels::GemmRef(A.get(), B.get(), N, K, M, C_ref.get());
for (int i = 0; i < N * M; ++i) { for (int i = 0; i < N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1); EXPECT_NEAR(C_ref[i], C[i], 0.1);
} }
}
TEST(GEMMTest, gemv) {
index_t N = 17;
index_t K = 63;
std::unique_ptr<float[]> A(new float[N * K]);
std::unique_ptr<float[]> B(new float[K]);
std::unique_ptr<float[]> C(new float[N]);
std::unique_ptr<float[]> C_ref(new float[N]);
delete[]A; std::random_device rd;
delete[]B; std::mt19937 gen(rd());
delete[]C; std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + N * K,
[&gen, &nd] {
return nd(gen);
});
std::generate(B.get(), B.get() + K,
[&gen, &nd] {
return nd(gen);
});
kernels::Gemv(A.get(), B.get(), K, N, C.get());
kernels::GemvRef(A.get(), B.get(), K, N, C_ref.get());
for (int i = 0; i < N; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
} }
} // namespace mace } // namespace mace
...@@ -308,7 +308,7 @@ void FullyConnectedTestNEON(const index_t batch, ...@@ -308,7 +308,7 @@ void FullyConnectedTestNEON(const index_t batch,
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"), ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"), *net.GetOutput("OutputNeon"),
0.001); 0.01);
} }
TEST_F(FullyConnectedOpTest, TestNEON) { TEST_F(FullyConnectedOpTest, TestNEON) {
......
...@@ -101,8 +101,13 @@ class Shapes(object): ...@@ -101,8 +101,13 @@ class Shapes(object):
return output_shape return output_shape
@staticmethod @staticmethod
def fully_connected_shape(input_shape, weight_shape): def fully_connected_shape(input_shape, weight_shape, input_format='NHWC'):
return [input_shape[0], 1, 1, weight_shape[0]] if input_format == 'NHWC':
return [input_shape[0], 1, 1, weight_shape[0]]
elif input_format == 'NCHW':
return [input_shape[0], weight_shape[0], 1, 1]
else:
raise Exception("format %s is not supported" % input_format)
@staticmethod @staticmethod
def concat_shape(input_shapes, axis): def concat_shape(input_shapes, axis):
...@@ -445,6 +450,18 @@ class CaffeConverter(object): ...@@ -445,6 +450,18 @@ class CaffeConverter(object):
final_op.output_shape_map[final_op.layer.top[0]] = output_shape final_op.output_shape_map[final_op.layer.top[0]] = output_shape
self.resolved_ops.add(activation_op.name) self.resolved_ops.add(activation_op.name)
if op_def.type in ("Conv2D", "FusedConv2D") and \
output_shape[2] == 1 and \
((input_format == 'NCHW' and output_shape[3] == 1) or
(input_format == 'NHWC' and output_shape[1] == 1)):
print "convert op %s from CONV to FC" % op.name
op_def.type = 'FC'
filter_shape = weight_data.shape
new_shape = [filter_shape[0],
filter_shape[1] * filter_shape[2] * filter_shape[3],
1, 1]
weight_data.reshape(new_shape)
op_def.output.extend([final_op.name + ':0']) op_def.output.extend([final_op.name + ':0'])
self.add_output_shape(op_def, output_shape) self.add_output_shape(op_def, output_shape)
self.net_def.op.extend([op_def]) self.net_def.op.extend([op_def])
......
...@@ -402,6 +402,19 @@ class TFConverter(object): ...@@ -402,6 +402,19 @@ class TFConverter(object):
final_op = op final_op = op
self.resolved_ops[op.name] = 1 self.resolved_ops[op.name] = 1
# convert global conv to fc
filter_shape = get_input_tensor(op, 1).shape.as_list()
input_shape = get_input_tensor(op, 0).shape.as_list()
if op_def.type == "Conv2D" and input_shape[1] == filter_shape[0] and \
input_shape[2] == filter_shape[1] and \
(op.get_attr('padding') == 'VALID' or filter_shape[0] == 1 and
filter_shape[1] == 1):
print "convert op %s from CONV to FC" % op.name
op_def.type = 'FC'
self.reshape_tensor[get_input_tensor(op, 1).name] = \
[filter_shape[3],
filter_shape[2] * filter_shape[1] * filter_shape[0], 1, 1]
if len(self.tf_graph.get(op.name, [])) == 1 and \ if len(self.tf_graph.get(op.name, [])) == 1 and \
self.tf_graph[op.name][0].type == 'BiasAdd': self.tf_graph[op.name][0].type == 'BiasAdd':
bias_add_op = self.tf_graph[op.name][0] bias_add_op = self.tf_graph[op.name][0]
......
...@@ -190,7 +190,7 @@ DEFINE_int32(restart_round, 1, "restart round"); ...@@ -190,7 +190,7 @@ DEFINE_int32(restart_round, 1, "restart round");
DEFINE_int32(malloc_check_cycle, -1, "malloc debug check cycle, -1 to disable"); DEFINE_int32(malloc_check_cycle, -1, "malloc debug check cycle, -1 to disable");
DEFINE_int32(gpu_perf_hint, 2, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH"); DEFINE_int32(gpu_perf_hint, 2, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH");
DEFINE_int32(gpu_priority_hint, 1, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH"); DEFINE_int32(gpu_priority_hint, 1, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH");
DEFINE_int32(omp_num_threads, 8, "num of openmp threads"); DEFINE_int32(omp_num_threads, 4, "num of openmp threads");
DEFINE_int32(cpu_power_option, DEFINE_int32(cpu_power_option,
0, 0,
"0:DEFAULT/1:HIGH_PERFORMANCE/2:BATTERY_SAVE"); "0:DEFAULT/1:HIGH_PERFORMANCE/2:BATTERY_SAVE");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册