提交 2d62cf4e 编写于 作者: H hjchen2

Optimize float im2col for general kernel and pad

上级 1535b7e7
......@@ -44,11 +44,15 @@ void DequantizeKernel<CPU, float>::Compute(
size_t loop = size >> 4;
size_t remain = size & 0xF;
float32x4_t s = vdupq_n_f32(scale);
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
const int32_t *local_x = x + (i << 4);
float *local_y = y + (i << 4);
int32x4_t r0 = vld1q_s32(local_x);
int32x4_t r1 = vld1q_s32(local_x + 4);
int32x4_t r2 = vld1q_s32(local_x + 8);
int32x4_t r3 = vld1q_s32(local_x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
......@@ -57,14 +61,14 @@ void DequantizeKernel<CPU, float>::Compute(
f1 = vmulq_f32(f1, s);
f2 = vmulq_f32(f2, s);
f3 = vmulq_f32(f3, s);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
x += 16;
y += 16;
vst1q_f32(local_y, f0);
vst1q_f32(local_y + 4, f1);
vst1q_f32(local_y + 8, f2);
vst1q_f32(local_y + 12, f3);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = x[i] * scale;
......
......@@ -22,6 +22,70 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void ExtractToImg(const float *im_data, float *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
int start_height = kh + col_start_height * stride_h - padding_h;
int start_width = kw + col_start_width * stride_w - padding_w;
int end_height = (col_height - col_start_height) * stride_h + start_height;
end_height = end_height > im_height ? im_height : end_height;
int end_width = (col_width - col_start_width) * stride_w + start_width;
end_width = end_width > im_width ? im_width : end_width;
int extract = (end_width - start_width + stride_w - 1) / stride_w;
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
#pragma omp parallel for
for (int i = start_height; i < end_height; i += stride_h) {
const float *local_im_data = im_data + i * im_width * stride_h;
float *local_col_data = col_data + col_width;
if (stride_w == 1) {
memcpy(local_col_data, local_im_data, extract * sizeof(float));
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
float32x4x2_t img = vld2q_f32(local_im_data + s * 2);
vst1q_f32(local_col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
local_col_data[s] = local_im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
float32x4x3_t img = vld3q_f32(local_im_data + s * 3);
vst1q_f32(local_col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
local_col_data[s] = local_im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
float32x4x4_t img = vld4q_f32(local_im_data + s * 4);
vst1q_f32(local_col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
local_col_data[s] = local_im_data[s * 4];
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
}
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
......@@ -363,7 +427,22 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
// pad 0
memset(col_data, 0, col->numel() * sizeof(float));
for (int ic = 0; ic < im_channels; ++ic) {
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(im_data, col_data, im_height, im_width, col_height,
col_width, padding[0], padding[1], stride[0], stride[1],
kh, kw);
col_data += col_height * col_width;
}
}
im_data += im_height * im_width;
}
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
......@@ -382,25 +461,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
}
}
}
}
#else
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
}
#if __ARM_NEON
}
#endif
}
......@@ -489,7 +550,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
int channels_col = im_channels * filter_height * filter_width;
const int8_t *im_data = im.data<int8_t>();
int8_t *col_data = col->data<int8_t>();
int8_t *col_data = col->mutable_data<int8_t>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
// pad 0
......
......@@ -334,11 +334,13 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-fssd net/test_mobilenet_025_fssd.cpp test_helper.h test_include.h)
target_link_libraries(test-fssd paddle-mobile)
# gen test
ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h)
target_link_libraries(test-multi-process paddle-mobile)
# gen test benchmark
ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp)
target_link_libraries(test-benchmark paddle-mobile)
#add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp)
endif ()
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <string>
#include "../test_helper.h"
#include "../test_include.h"
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
FILE *fp;
fp = fopen(file_name, "rb");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: " << std::endl
<< "./test_benchmark fluid_model feed_shape thread_num [use_fuse]"
<< std::endl;
std::cout << "use_fuse: optional, bool, default is 1\n";
return 1;
}
bool optimize = true;
char* fluid_model = argv[1];
char* feed_shape = argv[2];
int thread_num = atoi(argv[3]);
if (argc == 5) {
optimize = atoi(argv[4]);
}
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(thread_num);
auto time1 = time();
if (paddle_mobile.Load(fluid_model, optimize)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::framework::Tensor input;
std::shared_ptr<paddle_mobile::framework::Tensor> output;
std::vector<int64_t> dims{1, 3, 224, 224};
if (feed_shape) {
sscanf(feed_shape, "%d,%d,%d,%d", &dims[0], &dims[1], &dims[2], &dims[3]);
}
std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", "
<< dims[2] << ", " << dims[3] << "]\n";
paddle_mobile::framework::DDim in_shape =
paddle_mobile::framework::make_ddim(dims);
SetupTensor<float>(&input, in_shape, 0.f, 255.f);
// warmup
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
}
return 0;
}
......@@ -20,22 +20,21 @@ int main() {
#ifdef PADDLE_MOBILE_FPGA
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
#endif
#ifdef PADDLE_MOBILE_CPU
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
#endif
paddle_mobile.SetThreadNum(4);
paddle_mobile.SetThreadNum(1);
bool optimize = true;
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::cout << "load cost: " << time_diff(time1, time2) << "ms\n";
std::vector<float> input;
std::vector<float> output;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// 预热十次
// warmup
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input, dims);
}
......@@ -45,8 +44,7 @@ int main() {
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
std::cout << "predict cost: " << time_diff(time3, time4) / 10 << "ms\n";
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册