diff --git a/CMakeLists.txt b/CMakeLists.txt index 6feabdbe4374c9200c4282f620fadc27f3128bc9..43aea26f59802c4e58cecaf2313288ba2d1f307b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,11 +7,20 @@ option(USE_EXCEPTION "use std exception" ON) option(LOG_PROFILE "log profile" ON) # select the platform to build option(CPU "armv7 with neon" ON) -option(MALI_GPU "mali gpu" ON) +option(MALI_GPU "mali gpu" OFF) option(FPGA "fpga" OFF) set(DEBUGING ON) + +file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c) +file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h) + if (CPU) add_definitions(-DPADDLE_MOBILE_CPU) +else() + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/arm/*.h) + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/arm/*.cc) + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/arm/*.cpp) + endif() if (MALI_GPU) @@ -27,15 +36,24 @@ if (MALI_GPU) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -L${ACL_ROOT}/build/opencl-1.2-stubs") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lOpenCL") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ACL=1") +else() + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/mali/*.h) + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/mali/*.cc) + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/mali/*.cpp) + + endif() if(FPGA) - add_definitions(-DPADDLE_MOBILE_FPGA) + add_definitions(-DPADDLE_MOBILE_FPGA) +else() + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/fpga/*.h) + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/fpga/*.cc) + list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/fpga/*.cpp) endif() set(CMAKE_CXX_FLAGS "-std=c++14 -O3 -s ${CMAKE_CXX_FLAGS}") - if (DEBUGING) message(STATUS "debug") set(CMAKE_BUILD_TYPE Debug) @@ -69,8 +87,7 @@ if(USE_OPENMP) endif() -file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c) -file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h) + if (NOT ANDROID_NDK_TOOLCHAIN_INCLUDED) list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/jni/*.cpp) diff --git a/src/common/common.h b/src/common/common.h new file mode 100644 index 0000000000000000000000000000000000000000..12157b5e946490d041f0cc0d235142a13a3a2527 --- /dev/null +++ b/src/common/common.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +using Time = decltype(std::chrono::high_resolution_clock::now()); + +inline Time time() { return std::chrono::high_resolution_clock::now(); } + +inline double time_diff(Time t1, Time t2) { + typedef std::chrono::microseconds ms; + auto diff = t2 - t1; + ms counter = std::chrono::duration_cast(diff); + return counter.count() / 1000.0; +} diff --git a/src/common/log.h b/src/common/log.h index a3cefe2541e310897ce753b8eb69711242762122..faab6b31a31d7ce7148f96630900aff82931c771 100644 --- a/src/common/log.h +++ b/src/common/log.h @@ -120,7 +120,7 @@ struct ToLog { if (level > paddle_mobile::log_level) { \ } else \ paddle_mobile::ToLog( \ - level, static_cast( \ + level, static_cast( \ std::stringstream() \ << "[file: " \ << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \ @@ -133,7 +133,7 @@ struct ToLog { } else \ paddle_mobile::ToLog( \ paddle_mobile::kLOG_DEBUG, \ - static_cast( \ + static_cast( \ std::stringstream() \ << "[file: " \ << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \ diff --git a/src/framework/tensor.h b/src/framework/tensor.h index a221a26aa1435000646cf7d58321df28f3322834..9bbd81aa30f6fa0188dacd0dce01813e17b9e339 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -22,6 +22,7 @@ limitations under the License. */ #include #include "common/enforce.h" +#include #include "common/enforce.h" #include "framework/data_layout.h" #include "framework/ddim.h" @@ -131,6 +132,22 @@ class Tensor { return reinterpret_cast(mutable_data(typeid(T))); } +#ifdef PADDLE_MOBILE_DEBUG + template + inline void dump(std::string filename) const { + const T *dataptr = data(); + std::ofstream out(filename.c_str()); + for (int i = 0; i < numel(); ++i) { + out << dataptr[i] << " "; + } + out << "形状:"; + for (int j = 0; j < dims_.size(); ++j) { + out << dims_[j] << " "; + } + out.close(); + } +#endif + inline void *mutable_data(std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h index 24f1d3f63b3300db9b60a595466a0ced3b9e996b..73107a3c0adc382dea98663188215ad295c4506b 100644 --- a/src/operators/fusion_conv_add.h +++ b/src/operators/fusion_conv_add.h @@ -11,9 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define FUSION_CONVADD_OP -#ifdef FUSION_CONVADD_OP +#ifdef FUSION_CONVADD_OP #pragma once #include diff --git a/src/operators/kernel/arm/conv_add_kernel.cpp b/src/operators/kernel/arm/conv_add_kernel.cpp index 4bde8289007415dccbc7a630c7646ac718087c55..2c7aef932dc68e7a29bf60760751be0f9598cd42 100644 --- a/src/operators/kernel/arm/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/conv_add_kernel.cpp @@ -23,8 +23,7 @@ bool ConvAddKernel::Init(const FusionConvAddParam ¶) const { return true; } -template <> -void ConvAddKernel::Compute(const FusionConvAddParam ¶m) const { +void ConvAddBasic(const FusionConvAddParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor bias = *param.Bias(); @@ -102,7 +101,6 @@ void ConvAddKernel::Compute(const FusionConvAddParam ¶m) const { // vol2col vol2col(in_slice, dilations, strides, paddings, &col); } - // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); @@ -112,6 +110,26 @@ void ConvAddKernel::Compute(const FusionConvAddParam ¶m) const { } } } + +template <> +void ConvAddKernel::Compute(const FusionConvAddParam ¶m) const { + if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), + param.Bias(), true); + } else if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3) { + math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), + param.Filter(), param.Bias(), param.Output(), true); + } else { + ConvAddBasic(param); + } +} + template class ConvAddKernel; } // namespace operators diff --git a/src/operators/kernel/conv_add_kernel.h b/src/operators/kernel/conv_add_kernel.h index 8f733f245dc26664ce38413a09fc5404029cdd2f..fb161238fee0550a42cd62cc132d6e8dbf45872f 100644 --- a/src/operators/kernel/conv_add_kernel.h +++ b/src/operators/kernel/conv_add_kernel.h @@ -20,9 +20,11 @@ limitations under the License. */ #if __ARM_NEON #include #endif +#include "common/common.h" #include "framework/ddim.h" #include "framework/operator.h" #include "operators/math/conv_func.h" +#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c37cdea8fae1b5ec139cefbec82511ce948bff5 --- /dev/null +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -0,0 +1,506 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "operators/math/depthwise_conv_3x3.h" +#include + +namespace paddle_mobile { +namespace operators { +namespace math { +void DepthwiseConv3x3(const Tensor *input, vector strides, + vector paddings, const Tensor *filter, Tensor *bias, + Tensor *output, bool if_bias) { +#if __ARM_NEON + const int batch_size = input->dims()[0]; + + const int input_height = input->dims()[2]; + + const int input_width = input->dims()[3]; + + const int output_channels = output->dims()[1]; + + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + const int _kernel_size = 3; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const float zero = 0; + const int input_channel_stride = input_height * input_width; + const int output_channel_stride = output_height * output_width; + const int filter_channel_stride = 9; + + const float *input_data = input->data(); + const float *filter_data = filter->data(); + if (if_bias) { + math::expand_bias(*bias, 1, output->dims()); + output->ShareDataWith(*bias); + } + float *output_data = output->mutable_data(); + + const int input_batch_stride = output_channels * input_channel_stride; + const int output_batch_stride = output_channels * output_channel_stride; + const int filter_batch_stride = output_channels * output_channel_stride; + const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr; + int hstart, wstart, hend, wend; + float result; + for (int i = 0; i < batch_size; ++i) { + for (int c = 0; c < output_channels; ++c) { + filter1 = filter_data; + filter2 = filter1 + 3; + filter3 = filter2 + 3; + + for (int ph = 0; ph < output_height; ph++) { + for (int pw = 0; pw < output_width; pw++) { + hstart = ph * stride_height - padding_height; + wstart = pw * stride_width - padding_width; + hend = min(hstart + _kernel_size, input_height + padding_height); + wend = min(wstart + _kernel_size, input_width + padding_width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, input_height); + wend = min(wend, input_width); + pos1 = input_data + hstart * input_width + wstart; + pos2 = input_data + (hstart + 1) * input_width + wstart; + pos3 = input_data + (hstart + 2) * input_width + wstart; + output_ptr = output_data + ph * output_width + pw; + + if (hend - hstart != 3 || wend - wstart != 3) { + result = 0; + float fake_input[9] = {0}; + if (hstart == 0 && wstart == 0) { + // 左上角 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (j >= 3 - hend && k >= 3 - wend) { + fake_input[3 * j + k] = + input_data[(j - (3 - hend)) * input_width + k - + (3 - wend)]; + } + } + } + } else if (hstart == 0 && wend == input_width) { + // 右上角 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (j >= 3 - hend && k <= input_width - wstart - 1) { + fake_input[3 * j + k] = + input_data[(j - (3 - hend)) * input_width + k + wstart]; + } + } + } + + } else if (hend == input_height && wstart == 0) { + // 左下角 + + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (j <= input_height - 1 - hstart && k >= 3 - wend) { + fake_input[3 * j + k] = + input_data[(j + hstart) * input_width + k - (3 - wend)]; + } + } + } + } else if (hend == input_height && wend == input_width) { + // 右下角 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (j <= input_height - hstart - 1 && + k <= input_width - wstart - 1) { + fake_input[3 * j + k] = + input_data[(j + hstart) * input_width + k + wstart]; + } + } + } + } else if (hstart == 0) { + // 顶部 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (j >= 3 - hend) { + fake_input[3 * j + k] = + input_data[(j - (3 - hend)) * input_width + k + wstart]; + } + } + } + + } else if (hend == input_height) { + // 底部 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (j <= input_height - hstart - 1) { + fake_input[3 * j + k] = + input_data[(j + hstart) * input_width + k + wstart]; + } + } + } + + } else if (wstart == 0) { + // 左侧 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (k >= 3 - wend) { + fake_input[3 * j + k] = + input_data[(j + hstart) * input_width + + (k - (3 - wend))]; + } + } + } + + } else if (wend == input_width) { + // 右侧 + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + if (k <= input_width - wstart - 1) { + fake_input[3 * j + k] = + input_data[(j + hstart) * input_width + k + wstart]; + } + } + } + } + for (int l = 0; l < 9; ++l) { + result += fake_input[l] * filter1[l]; + } + if (if_bias) { + output_data[ph * output_width + pw] += result; + } else { + output_data[ph * output_width + pw] = result; + } + + } else { +#if defined(ARMV17) + asm volatile( + + "vld1.32 {q1}, [%[pos1]] \n\t" + "vld1.32 {q4}, [%[filter1]] \n\t" + "vmov.f32 q0, #0.0 \n\t" + + "vld1.32 {q2}, [%[pos2]] \n\t" + "vld1.32 {q5}, [%[filter2]] \n\t" + "vmla.f32 q0, q1, q4 \n\t" + + "vld1.32 {q3}, [%[pos3]] \n\t" + "vld1.32 {q6}, [%[filter3]] \n\t" + + "vmla.f32 q0, q2, q5 \n\t" + "vmla.f32 q0, q3, q6 \n\t" + + "vmov.f32 d1[1], %[zero] \n\t" + + "vadd.f32 d4, d0, d1 \n\t" + "vadd.f32 s10, s8, s9 \n\t" + "vst1.32 {d5[0]},[%[output_ptr]] \n\t" + : + : [input_data] "r"(input_data), [pos1] "r"(pos1), + [pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1), + [filter2] "r"(filter2), [filter3] "r"(filter3), + [output_ptr] "r"(output_ptr), [zero] "r"(zero) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); +#else + const float32x4_t data1 = vld1q_f32(pos1); + const float32x4_t data2 = vld1q_f32(pos2); + const float32x4_t data3 = vld1q_f32(pos3); + + const float32x4_t v_filter1 = vld1q_f32(filter1); + const float32x4_t v_filter2 = vld1q_f32(filter2); + const float32x4_t v_filter3 = vld1q_f32(filter3); + float32x4_t mula = vmulq_f32(data1, v_filter1); + mula = vmlaq_f32(mula, data2, v_filter2); + mula = vmlaq_f32(mula, data3, v_filter3); + float32x2_t res = vpadd_f32( + vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula)); + res = vpadd_f32(res, res); + if (if_bias) { + output_data[ph * output_width + pw] += vget_lane_f32(res, 0); + } else { + output_data[ph * output_width + pw] = vget_lane_f32(res, 0); + } +#endif + } + } + } + input_data += input_channel_stride; + output_data += output_channel_stride; + filter_data += filter_channel_stride; + } + input_data += input_batch_stride; + output_data += output_batch_stride; + } +#endif +} + +void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, + Tensor *output, Tensor *bias, bool if_bias) { + const float *input_data = input->data(); + const float *filter_data = filter->data(); + float *output_data = output->data(); + const float *bias_data = bias->data(); + + const int h = static_cast(input->dims()[2]); + const int w = static_cast(input->dims()[3]); + const int l = h; + + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const int hxw = h * w; + float32x4_t vbias = vdupq_n_f32(0.0); + for (int b = 0; b < batch_size; ++b) { + const float *filter_data_tmp = filter_data; + + for (int j = 0; j < c; ++j) { + if (if_bias) { + vbias = vdupq_n_f32(bias_data[j]); + } + + int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + output_data[0] = w11 * input_data[0] + w12 * input_data[1] + + w21 * input_data[l] + w22 * input_data[l + 1] + + bias_data[j]; + output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + + w20 * input_data[2 * l - 2] + + w21 * input_data[2 * l - 1] + bias_data[j]; + output_data[(l - 1) * l] = + w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] + + bias_data[j]; + output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + + w01 * input_data[(l - 2) * (l + 1) + 1] + + w10 * input_data[l * l - 2] + + w11 * input_data[l * l - 1] + bias_data[j]; + + for (int i = 1; i < l - 1; ++i) { + output_data[i * l] = + w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + + w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] + + bias_data[j]; + output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + + w01 * input_data[i * l + l - 1 - l] + + w10 * input_data[i * l + l - 1 - 1] + + w11 * input_data[i * l + l - 1] + + w20 * input_data[i * l + l - 1 + l - 1] + + w21 * input_data[i * l + l - 1 + l] + + bias_data[j]; + } + + // top 1 row and bottom 1 row + const float *input_tmp = input_data; + + float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, + tmp3, tmp4, tmp5, out0; + in0 = vld1q_f32(input_tmp); + in2 = vld1q_f32(input_tmp + l); + const float *input_tmp_end = input_tmp + (l - 2) * l; + in4 = vld1q_f32(input_tmp_end); + in6 = vld1q_f32(input_tmp_end + l); + int c_mid = l_mid; + auto output_ptr = output_data + 1; + for (; c_mid > 3; c_mid -= 4) { + in1 = vld1q_f32(input_tmp + 4); + in3 = vld1q_f32(input_tmp + l + 4); + + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vaddq_f32(out0, vbias); + + vst1q_f32(output_ptr, out0); + + in5 = vld1q_f32(input_tmp_end + 4); + in7 = vld1q_f32(input_tmp_end + l + 4); + + tmp0 = vextq_f32(in4, in5, 1); + tmp1 = vextq_f32(in4, in5, 2); + tmp2 = vextq_f32(in6, in7, 1); + tmp3 = vextq_f32(in6, in7, 2); + + out0 = vmulq_n_f32(in4, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in6, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vaddq_f32(out0, vbias); + + vst1q_f32(output_ptr + (l - 1) * l, out0); + + // can optimize to each 8 stride. + input_tmp += 4; + input_tmp_end += 4; + output_ptr += 4; + in0 = in1; + in2 = in3; + in4 = in5; + in6 = in7; + } + + // top right pad + float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); + + tmp0 = vextq_f32(in0, pad0, 1); + tmp1 = vextq_f32(in0, pad0, 2); + tmp2 = vextq_f32(in2, pad1, 1); + tmp3 = vextq_f32(in2, pad1, 2); + + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vaddq_f32(out0, vbias); + + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } + } + + // bottom right pad + float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); + float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + + tmp0 = vextq_f32(in4, pad2, 1); + tmp1 = vextq_f32(in4, pad2, 2); + tmp2 = vextq_f32(in6, pad3, 1); + tmp3 = vextq_f32(in6, pad3, 2); + + out0 = vmulq_n_f32(in4, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in6, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vaddq_f32(out0, vbias); + + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + } + } + // mid + + for (int i = 0; i < l - 2; ++i) { + auto output_ptr = output_data + (i + 1) * l + 1; + input_tmp = input_data + i * l; + auto in0_tmp = vld1q_f32(input_tmp); + auto in2_tmp = vld1q_f32(input_tmp + l); + auto in4_tmp = vld1q_f32(input_tmp + l + l); + c_mid = l_mid; + for (; c_mid > 3; c_mid -= 4) { + auto in1_tmp = vld1q_f32(input_tmp + 4); + auto in3_tmp = vld1q_f32(input_tmp + l + 4); + auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); + + tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); + tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); + tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); + tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); + tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); + tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); + + out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vaddq_f32(out0, vbias); + + vst1q_f32(output_ptr, out0); + + output_ptr += 4; + input_tmp += 4; + in0_tmp = in1_tmp; + in2_tmp = in3_tmp; + in4_tmp = in5_tmp; + } + + float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + + tmp0 = vextq_f32(in0_tmp, pad0, 1); + tmp1 = vextq_f32(in0_tmp, pad0, 2); + tmp2 = vextq_f32(in2_tmp, pad1, 1); + tmp3 = vextq_f32(in2_tmp, pad1, 2); + tmp4 = vextq_f32(in4_tmp, pad2, 1); + tmp5 = vextq_f32(in4_tmp, pad2, 2); + + out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vaddq_f32(out0, vbias); + + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } + } + } + output_data += hxw; + input_data += hxw; + filter_data_tmp += 9; + } + } +} +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv_3x3.h b/src/operators/math/depthwise_conv_3x3.h new file mode 100644 index 0000000000000000000000000000000000000000..ab2a04369e1fc6e984ffa6f8f5667dd2a10e2a55 --- /dev/null +++ b/src/operators/math/depthwise_conv_3x3.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "framework/tensor.h" +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { +namespace math { +using framework::Tensor; +using std::max; +using std::min; +using std::vector; + +void DepthwiseConv3x3(const Tensor *input, vector strides, + vector paddings, const Tensor *filter, Tensor *bias, + Tensor *output, bool if_bias); +void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, + Tensor *output, Tensor *bias, bool if_bias); +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/test/net/test_mobilenet.cpp b/test/net/test_mobilenet.cpp index 8400b08f2292bb5655e2d85298acce603e1ce603..2495fb497e679d75128f3a74fdbb8da98b927f9f 100644 --- a/test/net/test_mobilenet.cpp +++ b/test/net/test_mobilenet.cpp @@ -33,12 +33,8 @@ int main() { input_tensor.data() + input_tensor.numel()); auto time3 = time(); auto vec_result = executor.Predict(input, dims); - float sum = 0; - for (const auto item : vec_result) { - sum += item; - } - DLOG << "mobilenet output sum =" << sum; auto time4 = time(); + DLOG << "predict cost :" << time_diff(time3, time4) << "ms"; return 0; } diff --git a/test/test_helper.h b/test/test_helper.h index fe720ded8270f2bc02a4f1e72625954962184069..81ad23ff3b4e53db0225630eebaa34878ad4c139 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once -#include #include #include +#include "common/common.h" #include "common/log.h" #include "framework/ddim.h" #include "framework/tensor.h" @@ -35,17 +35,6 @@ static const std::string g_test_image_1x3x224x224 = using paddle_mobile::framework::DDim; using paddle_mobile::framework::Tensor; -using Time = decltype(std::chrono::high_resolution_clock::now()); - -Time time() { return std::chrono::high_resolution_clock::now(); } - -double time_diff(Time t1, Time t2) { - typedef std::chrono::microseconds ms; - auto diff = t2 - t1; - ms counter = std::chrono::duration_cast(diff); - return counter.count() / 1000.0; -} - template void SetupTensor(paddle_mobile::framework::Tensor *input, paddle_mobile::framework::DDim dims, T lower, T upper) {