Commit 7a507e95 authored by hjchen2

fix group conv bug in gemm conv

Parent 8bfe4fc7
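Summary (as far as the diff below shows): the grouped path in GemmConv built the per-group GEMM input by chaining ShareDataWith calls, which apparently left col_matrix with the wrong shape/view in the no-expand (1x1, stride 1, pad 0) case; the fix assigns the input slice directly and then resizes it to col_matrix_shape. The commit also adds a generic EXEC_DEPTHWISE3x3_FLOAT fallback, replaces the hand-written NEON copy in Gemm::VecWriteBasic with a plain memcpy, and re-enables the kernel=3 float conv tests alongside a new kernel=1 case.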
......@@ -57,6 +57,8 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
} else if (depth3x3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
......@@ -106,6 +108,10 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
......
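With the new branch above, depthwise 3x3 convolutions whose stride or padding do not match the specialized s1p1/s2p1/s2p0 kernels now dispatch to the generic math::DepthwiseConv3x3 instead of falling through to the GemmConv path.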
......@@ -35,6 +35,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
Tensor filter = *param.Filter();
Tensor *output = param.Output();
output->mutable_data<Otype>();
int groups = param.Groups();
const std::vector<int> strides = param.Strides();
const std::vector<int> paddings = param.Paddings();
......@@ -90,8 +91,8 @@ inline void GemmConv(const ConvParam<CPU> &param) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
// assign instead of ShareDataWith so col_matrix keeps this slice's view
col_matrix = in_slice;
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
......@@ -107,6 +108,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMul<Itype, Otype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), false,
......
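To see why the no-expand branch may reuse the input slice directly, note that for a 1x1 kernel with stride 1 and pad 0 the im2col transform is the identity: each group's input slice, viewed as a (C_in/groups) x (H*W) matrix, already is the right-hand operand of the GEMM. Below is a self-contained sketch (plain C++ with a hypothetical MatMulRef helper, not the library's math::MatMul) that cross-checks this against the direct convolution definition:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Plain triple-loop GEMM: C(m x n) = A(m x k) * B(k x n).
static void MatMulRef(const float *A, const float *B, float *C,
                      int m, int k, int n) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = acc;
    }
}

int main() {
  const int groups = 2, in_c = 4, out_c = 6, h = 3, w = 3;
  const int in_step = in_c / groups;    // input channels per group
  const int out_step = out_c / groups;  // output channels per group

  std::vector<float> input(in_c * h * w), filter(out_c * in_step),
      output(out_c * h * w, 0.f);
  for (int i = 0; i < (int)input.size(); ++i) input[i] = 0.1f * i;
  for (int i = 0; i < (int)filter.size(); ++i) filter[i] = 0.01f * i;

  // Grouped 1x1 conv as per-group GEMM: filter_slice (out_step x in_step)
  // times the raw input slice viewed as an (in_step x h*w) matrix -- the
  // "col_matrix" of the no-expand branch, with no im2col buffer at all.
  for (int g = 0; g < groups; ++g) {
    const float *in_slice = input.data() + g * in_step * h * w;
    const float *filter_slice = filter.data() + g * out_step * in_step;
    float *out_slice = output.data() + g * out_step * h * w;
    MatMulRef(filter_slice, in_slice, out_slice, out_step, in_step, h * w);
  }

  // Cross-check one output element against the direct conv definition.
  const int g = 1, oc = 0, y = 1, x = 2;
  float ref = 0.f;
  for (int ic = 0; ic < in_step; ++ic)
    ref += filter[(g * out_step + oc) * in_step + ic] *
           input[(g * in_step + ic) * h * w + y * w + x];
  assert(std::fabs(output[(g * out_step + oc) * h * w + y * w + x] - ref) <
         1e-4f);
  std::printf("grouped 1x1 conv via direct GEMM matches: %f\n", ref);
  return 0;
}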
......@@ -2971,48 +2971,7 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
// C = A * B
void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
int nc2 = _nc1 / 4;
int nc3 = 16 - 4 * (_nc1 % 4);
asm volatile(
"subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t"
"vst1.32 {q0, q1}, [%[C]]! \n\t"
"vld1.32 {q2, q3}, [%[c]]! \n\t"
"vst1.32 {q2, q3}, [%[C]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t"
"vld1.32 {q4}, [%[c]]! \n\t"
"vst1.32 {q4}, [%[C]]! \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"bge loop_nc2_%= \n\t"
"end_nc2_%=: \n\t"
"cmp %[nc3], #16 \n\t"
"beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q5}, [%[c]]! \n\t"
"vst1.32 {q5}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
:
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5");
memcpy(C, c, n * sizeof(float));
}
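The removed assembly copies in 16- and 4-float chunks, then handles the remainder by backing both pointers up by nc3 bytes and re-copying a full 16-byte vector; that overlapped-tail trick reads and writes before the start of the buffers whenever n < 4, and the pointers are modified through input-only "r" constraints, which GCC's inline-asm rules do not allow. Replacing the routine with memcpy sidesteps both issues, and libc's memcpy is itself NEON-optimized on ARM. For reference, a safe vectorized copy in NEON intrinsics might look like this sketch (hypothetical VecCopy helper, compiles only for ARM targets):

#include <arm_neon.h>

static void VecCopy(int n, const float *src, float *dst) {
  int i = 0;
  // Main loop: 4 floats per iteration.
  for (; i + 4 <= n; i += 4) vst1q_f32(dst + i, vld1q_f32(src + i));
  // Scalar tail: no out-of-bounds access for any n.
  for (; i < n; ++i) dst[i] = src[i];
}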
// C = alpha * A * B + beta * C
......
......@@ -228,39 +228,43 @@ int TestAll(const int in_channels, const int in_height, const int in_width,
std::cerr << "in_channels=" << in_channels << ", in_height=" << in_height
<< ", in_width=" << in_width << ", out_channels=" << out_channels
<< ", groups=" << groups << std::endl;
// // kernel = 3, pad = 0, stride = 1
// std::cerr << "float, kernel=3, pad=0, stride=1" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 0, 1>(
// in_channels, in_height, in_width, out_channels, groups);
// // kernel = 3, pad = 1, stride = 1
// std::cerr << "float, kernel=3, pad=1, stride=1" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 1, 1>(
// in_channels, in_height, in_width, out_channels, groups);
// // kernel = 3, pad = 2, stride = 1
// std::cerr << "float, kernel=3, pad=2, stride=1" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 2, 1>(
// in_channels, in_height, in_width, out_channels, groups);
// // kernel = 3, pad = 5, stride = 1
// std::cerr << "float, kernel=3, pad=5, stride=1" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 5, 1>(
// in_channels, in_height, in_width, out_channels, groups);
//
// // kernel = 3, pad = 0, stride = 2
// std::cerr << "float, kernel=3, pad=0, stride=2" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 0, 2>(
// in_channels, in_height, in_width, out_channels, groups);
// // kernel = 3, pad = 1, stride = 2
// std::cerr << "float, kernel=3, pad=1, stride=2" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 1, 2>(
// in_channels, in_height, in_width, out_channels, groups);
// // kernel = 3, pad = 2, stride = 2
// std::cerr << "float, kernel=3, pad=2, stride=2" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 2, 2>(
// in_channels, in_height, in_width, out_channels, groups);
// // kernel = 3, pad = 5, stride = 2
// std::cerr << "float, kernel=3, pad=5, stride=2" << std::endl;
// paddle_mobile::TestConvOp<float, float, 3, 5, 2>(
// in_channels, in_height, in_width, out_channels, groups);
std::cerr << "float, kernel=1, pad=0, stride=1" << std::endl;
paddle_mobile::TestConvOp<float, float, 1, 0, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 0, stride = 1
std::cerr << "float, kernel=3, pad=0, stride=1" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 0, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 1, stride = 1
std::cerr << "float, kernel=3, pad=1, stride=1" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 1, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 2, stride = 1
std::cerr << "float, kernel=3, pad=2, stride=1" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 2, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 5, stride = 1
std::cerr << "float, kernel=3, pad=5, stride=1" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 5, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 0, stride = 2
std::cerr << "float, kernel=3, pad=0, stride=2" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 0, 2>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 1, stride = 2
std::cerr << "float, kernel=3, pad=1, stride=2" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 1, 2>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 2, stride = 2
std::cerr << "float, kernel=3, pad=2, stride=2" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 2, 2>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 3, pad = 5, stride = 2
std::cerr << "float, kernel=3, pad=5, stride=2" << std::endl;
paddle_mobile::TestConvOp<float, float, 3, 5, 2>(
in_channels, in_height, in_width, out_channels, groups);
#ifndef __aarch64__
// kernel = 3, pad = 0, stride = 1
......@@ -338,6 +342,7 @@ int TestAll(const int in_channels, const int in_height, const int in_width,
}
int main() {
TestAll(16, 10, 10, 16, 16);
TestAll(1, 5, 5, 1, 1);
TestAll(1, 5, 5, 10, 1);
TestAll(10, 5, 5, 10, 10);
......
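The updated TestAll matrix now exercises the paths this commit touches: the new kernel=1 case drives the no-expand branch in GemmConv, the re-enabled kernel=3 cases cover the generic depthwise fallback, and the added TestAll(1, 5, 5, 1, 1) call checks the smallest grouped configuration alongside the existing depthwise-style (groups == in_channels) and groups == 1 runs.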