提交 6df374b4 编写于 作者: Y Yao,kun

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle-mobile into develop

......@@ -7,11 +7,20 @@ option(USE_EXCEPTION "use std exception" ON)
option(LOG_PROFILE "log profile" ON)
# select the platform to build
option(CPU "armv7 with neon" ON)
option(MALI_GPU "mali gpu" ON)
option(MALI_GPU "mali gpu" OFF)
option(FPGA "fpga" OFF)
set(DEBUGING ON)
file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c)
file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h)
if (CPU)
add_definitions(-DPADDLE_MOBILE_CPU)
else()
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/arm/*.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/arm/*.cc)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/arm/*.cpp)
endif()
if (MALI_GPU)
......@@ -27,15 +36,24 @@ if (MALI_GPU)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -L${ACL_ROOT}/build/opencl-1.2-stubs")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lOpenCL")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ACL=1")
else()
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/mali/*.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/mali/*.cc)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/mali/*.cpp)
endif()
if(FPGA)
add_definitions(-DPADDLE_MOBILE_FPGA)
add_definitions(-DPADDLE_MOBILE_FPGA)
else()
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/fpga/*.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/fpga/*.cc)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/kernel/fpga/*.cpp)
endif()
set(CMAKE_CXX_FLAGS "-std=c++14 -O3 -s ${CMAKE_CXX_FLAGS}")
if (DEBUGING)
message(STATUS "debug")
set(CMAKE_BUILD_TYPE Debug)
......@@ -69,8 +87,7 @@ if(USE_OPENMP)
endif()
file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c)
file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h)
if (NOT ANDROID_NDK_TOOLCHAIN_INCLUDED)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/jni/*.cpp)
......
......@@ -28,6 +28,6 @@ RUN apt-get autoremove -y && apt-get clean
RUN pip install --upgrade pip
RUN pip install wheel && pip install pre-commit
RUN ln -s clang-format-5.0 /usr/bin/clang-format
# RUN cd /tmp && curl -O http://mirrors.neusoft.edu.cn/android/repository/android-ndk-r17b-linux-x86_64.zip
# RUN cd /opt && unzip /tmp/android-ndk-r17b-linux-x86_64.zip
# ENV NDK_ROOT /opt/android-ndk-r17b
RUN cd /tmp && curl -O http://mirrors.neusoft.edu.cn/android/repository/android-ndk-r17b-linux-x86_64.zip
RUN cd /opt && unzip /tmp/android-ndk-r17b-linux-x86_64.zip
ENV NDK_ROOT /opt/android-ndk-r17b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <chrono>
using Time = decltype(std::chrono::high_resolution_clock::now());
inline Time time() { return std::chrono::high_resolution_clock::now(); }
inline double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
......@@ -120,7 +120,7 @@ struct ToLog {
if (level > paddle_mobile::log_level) { \
} else \
paddle_mobile::ToLog( \
level, static_cast<std::stringstream &>( \
level, static_cast<const std::stringstream &>( \
std::stringstream() \
<< "[file: " \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
......@@ -133,7 +133,7 @@ struct ToLog {
} else \
paddle_mobile::ToLog( \
paddle_mobile::kLOG_DEBUG, \
static_cast<std::stringstream &>( \
static_cast<const std::stringstream &>( \
std::stringstream() \
<< "[file: " \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
......
......@@ -29,7 +29,7 @@ class Program {
std::shared_ptr<Scope> scope;
std::string model_path;
std::string para_path;
bool is_commbine = false;
bool combined = false;
private:
};
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <vector>
#include "common/enforce.h"
#include <fstream>
#include "common/enforce.h"
#include "framework/data_layout.h"
#include "framework/ddim.h"
......@@ -131,6 +132,22 @@ class Tensor {
return reinterpret_cast<T *>(mutable_data(typeid(T)));
}
#ifdef PADDLE_MOBILE_DEBUG
template <typename T>
inline void dump(std::string filename) const {
const T *dataptr = data<T>();
std::ofstream out(filename.c_str());
for (int i = 0; i < numel(); ++i) {
out << dataptr[i] << " ";
}
out << "形状:";
for (int j = 0; j < dims_.size(); ++j) {
out << dims_[j] << " ";
}
out.close();
}
#endif
inline void *mutable_data(std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
......
......@@ -88,7 +88,7 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
bool optimize) {
auto program = this->LoadProgram(model_path, optimize);
program.para_path = para_path;
program.is_commbine = true;
program.combined = true;
return program;
}
......@@ -193,7 +193,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
#endif
}
}
if (program_.is_commbine) {
if (program_.combined) {
InitCombineMemory();
} else {
InitMemory();
......
......@@ -11,9 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define FUSION_CONVADD_OP
#ifdef FUSION_CONVADD_OP
#ifdef FUSION_CONVADD_OP
#pragma once
#include <string>
......
......@@ -23,8 +23,7 @@ bool ConvAddKernel<CPU, float>::Init(const FusionConvAddParam &para) const {
return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
void ConvAddBasic(const FusionConvAddParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
......@@ -102,7 +101,6 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
......@@ -112,6 +110,26 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
}
}
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), param.Bias(), param.Output(), true);
} else {
ConvAddBasic(param);
}
}
template class ConvAddKernel<CPU, float>;
} // namespace operators
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#pragma once
#include <cmath>
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -20,9 +20,11 @@ limitations under the License. */
#if __ARM_NEON
#include <arm_neon.h>
#endif
#include "common/common.h"
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/depthwise_conv_3x3.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
if (if_bias) {
math::expand_bias(*bias, 1, output->dims());
output->ShareDataWith(*bias);
}
float *output_data = output->mutable_data<float>();
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const int filter_batch_stride = output_channels * output_channel_stride;
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr;
int hstart, wstart, hend, wend;
float result;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
filter1 = filter_data;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = min(hstart + _kernel_size, input_height + padding_height);
wend = min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, input_height);
wend = min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
// 左上角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
// 右上角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
// 左下角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
// 右下角
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
// 顶部
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
// 底部
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
// 左侧
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
// 右侧
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
if (if_bias) {
output_data[ph * output_width + pw] += result;
} else {
output_data[ph * output_width + pw] = result;
}
} else {
#if defined(ARMV17)
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q4}, [%[filter1]] \n\t"
"vmov.f32 q0, #0.0 \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
"vld1.32 {q5}, [%[filter2]] \n\t"
"vmla.f32 q0, q1, q4 \n\t"
"vld1.32 {q3}, [%[pos3]] \n\t"
"vld1.32 {q6}, [%[filter3]] \n\t"
"vmla.f32 q0, q2, q5 \n\t"
"vmla.f32 q0, q3, q6 \n\t"
"vmov.f32 d1[1], %[zero] \n\t"
"vadd.f32 d4, d0, d1 \n\t"
"vadd.f32 s10, s8, s9 \n\t"
"vst1.32 {d5[0]},[%[output_ptr]] \n\t"
:
: [input_data] "r"(input_data), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1),
[filter2] "r"(filter2), [filter3] "r"(filter3),
[output_ptr] "r"(output_ptr), [zero] "r"(zero)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
#else
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
if (if_bias) {
output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
} else {
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
}
#endif
}
}
}
input_data += input_channel_stride;
output_data += output_channel_stride;
filter_data += filter_channel_stride;
}
input_data += input_batch_stride;
output_data += output_batch_stride;
}
#endif
}
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias) {
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vbias = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1] +
bias_data[j];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1] + bias_data[j];
output_data[(l - 1) * l] =
w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
bias_data[j];
output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
w01 * input_data[(l - 2) * (l + 1) + 1] +
w10 * input_data[l * l - 2] +
w11 * input_data[l * l - 1] + bias_data[j];
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] +
bias_data[j];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l] +
bias_data[j];
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + l + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
vst1q_f32(output_ptr + (l - 1) * l, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
}
}
// mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::max;
using std::min;
using std::vector;
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -33,12 +33,8 @@ int main() {
input_tensor.data<float>() + input_tensor.numel());
auto time3 = time();
auto vec_result = executor.Predict(input, dims);
float sum = 0;
for (const auto item : vec_result) {
sum += item;
}
DLOG << "mobilenet output sum =" << sum;
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms";
return 0;
}
......@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once
#include <chrono>
#include <fstream>
#include <random>
#include "common/common.h"
#include "common/log.h"
#include "framework/ddim.h"
#include "framework/tensor.h"
......@@ -35,17 +35,6 @@ static const std::string g_test_image_1x3x224x224 =
using paddle_mobile::framework::DDim;
using paddle_mobile::framework::Tensor;
using Time = decltype(std::chrono::high_resolution_clock::now());
Time time() { return std::chrono::high_resolution_clock::now(); }
double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
template <typename T>
void SetupTensor(paddle_mobile::framework::Tensor *input,
paddle_mobile::framework::DDim dims, T lower, T upper) {
......
......@@ -37,7 +37,7 @@
# ANDROID_DISABLE_FORMAT_STRING_CHECKS
# ANDROID_CCACHE
cmake_minimum_required(VERSION 3.6.0)
# cmake_minimum_required(VERSION 3.6.0)
# Inhibit all of CMake's own NDK handling code.
set(CMAKE_SYSTEM_VERSION 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册