Unverified · commit b025553b · authored by H hong19860320, committed by GitHub

1. Update CPUContext to probe CPU info; 2. Improve the performance of...

1. Update CPUContext to probe CPU info; 2. Improve the performance of SlidingwindowConv3x3s1 and SlidingwindowConv3x3s2 (#1655)
Parent b53a20aa
......@@ -111,6 +111,26 @@ enum PoolingType {
FIRST = 3,
};
enum PowerMode {
PERFORMANCE_PRIORITY = 0,  // let threads run on the big cores if
                           // thread_num <= the number of big cores,
                           // otherwise fall back to AUTO and let the
                           // system schedule all threads
EFFICIENCY_PRIORITY = 1,   // let threads run on the little cores if
                           // thread_num <= the number of little cores,
                           // otherwise fall back to AUTO and let the
                           // system schedule all threads
PERFORMANCE_ONLY = 2,      // force threads to run on the big cores;
                           // threads beyond the number of big cores
                           // are ignored
EFFICIENCY_ONLY = 3,       // force threads to run on the little cores;
                           // threads beyond the number of little cores
                           // are ignored
AUTO = 4,                  // all threads scheduled by the system
};
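// A minimal usage sketch of these modes (hypothetical call site, assuming
// an 8-core big.LITTLE SoC with 4 big cores): with PERFORMANCE_PRIORITY
// and thread_num = 4 all threads are bound to the big cluster, while with
// thread_num = 8 the mode falls back to AUTO.
//
//   paddle_mobile::PaddleMobile<paddle_mobile::CPU, float> engine;
//   engine.SetThreadNum(4, paddle_mobile::PERFORMANCE_PRIORITY);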
struct PaddleMobileConfigInternal {
bool load_when_predict = false;
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/context.h"
#include <iostream>
#include <string>
#include "common/log.h"
#ifdef __APPLE__
#include "TargetConditionals.h"
#if TARGET_OS_IPHONE  // TargetConditionals.h always defines this (as 0 or
                      // 1), so test its value rather than its existence
// iOS
#elif TARGET_OS_MAC
// Mac OS
#else
// Unsupported platform
#endif
#include <mach/machine.h>
#include <sys/sysctl.h>
#include <sys/types.h>
#else // Linux or Android
#include <sys/syscall.h>
#include <unistd.h>
#endif
namespace paddle_mobile {
namespace framework {
const int DEFAULT_L1_CACHE_SIZE = 32 * 1024;
const int DEFAULT_L2_CACHE_SIZE = 2048 * 1024;
const int DEFAULT_L3_CACHE_SIZE = 0;
void fill_cpu_cache_size(std::vector<int> *cpu_cache_sizes, int value,
const std::vector<int> cpu_ids = {}) {
int num = cpu_ids.size();
if (num > 0) {
for (int i = 0; i < num; i++) {
(*cpu_cache_sizes)[cpu_ids[i]] = value;
}
} else {
num = cpu_cache_sizes->size();
for (int i = 0; i < num; i++) {
(*cpu_cache_sizes)[i] = value;
}
}
}
int get_cpu_num() {
#ifdef __APPLE__
int count = 0;
size_t len = sizeof(count);
sysctlbyname("hw.ncpu", &count, &len, NULL, 0);
if (count < 1) {
count = 1;
}
return count;
#else // Linux or Android
// get cpu count by probing /sys/devices/system/cpu/cpu<N>/uevent
int max_cpu_num = 20;
int count = 0;
for (int i = 0; i < max_cpu_num; i++) {
char path[256];
snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i);
FILE *fp = fopen(path, "rb");
if (!fp) {
break;
}
count++;
fclose(fp);
}
if (count < 1) {
count = 1;
}
return count;
#endif
}
#if !defined(__APPLE__) // Linux or Android
std::string get_cpu_name() {
FILE *fp = fopen("/proc/cpuinfo", "rb");
if (!fp) {
return "";
}
char line[1024];
while (!feof(fp)) {
char *s = fgets(line, 1024, fp);
if (!s) {
break;
}
if (strstr(line, "Hardware") != NULL) {
fclose(fp);
return std::string(line);
}
}
fclose(fp);
return "";
}
int get_cpu_max_freq_khz(int cpu_id) {
// first try, for all possible cpu
char path[256];
#ifdef __ANDROID__
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpu_id);
FILE *fp = fopen(path, "rb");
if (!fp) {
// second try, for online cpu
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state",
cpu_id);
fp = fopen(path, "rb");
if (!fp) {
// third try, read cpuinfo_max_freq directly
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq",
cpu_id);
fp = fopen(path, "rb");
if (!fp) {
return 0;
}
int max_freq_khz = 0;
if (fscanf(fp, "%d", &max_freq_khz) <= 0) {
max_freq_khz = 0;
}
fclose(fp);
return max_freq_khz;
}
}
int max_freq_khz = 0;
while (!feof(fp)) {
int freq_khz = 0;
int nscan = fscanf(fp, "%d %*d", &freq_khz);
if (nscan != 1) {
break;
}
if (freq_khz > max_freq_khz) {
max_freq_khz = freq_khz;
}
}
fclose(fp);
return max_freq_khz;
#else
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/scaling_max_freq", cpu_id);
FILE *fp = fopen(path, "r");
if (!fp) {
return 0;
}
int max_freq_khz = 0;
if (fscanf(fp, "%d", &max_freq_khz) <= 0) {
max_freq_khz = 0;
}
fclose(fp);
return max_freq_khz;
#endif
}
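// A sketch of the sysfs formats parsed above (assumed from the typical
// Linux cpufreq layout): time_in_state holds "freq_khz time" pairs, e.g.
//   1766400 2089
//   2457600 87311
// and the loop keeps the largest first field, while cpuinfo_max_freq and
// scaling_max_freq hold a single frequency in kHz.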
void get_cpu_cache_size(int cpu_id, int *l1_cache_size, int *l2_cache_size,
int *l3_cache_size) {
int max_cache_idx_num = 10;
*l1_cache_size = DEFAULT_L1_CACHE_SIZE;
*l2_cache_size = DEFAULT_L2_CACHE_SIZE;
*l3_cache_size = DEFAULT_L3_CACHE_SIZE;
for (int i = 0; i < max_cache_idx_num; i++) {
char path[256];
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cache/index%d/level", cpu_id, i);
FILE *fp = fopen(path, "rb");
if (fp) {
int level = -1;
fscanf(fp, "%d", &level);
fclose(fp);
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cache/index%d/size", cpu_id, i);
fp = fopen(path, "rb");
if (fp) {
int size = -1;
fscanf(fp, "%d", &size);
fclose(fp);
if (size >= 0) {
if (level == 1) {
*l1_cache_size = size * 1024;
} else if (level == 2) {
*l2_cache_size = size * 1024;
} else if (level == 3) {
*l3_cache_size = size * 1024;
}
}
}
}
}
}
int check_online(std::vector<int> *cpu_ids) {
if (cpu_ids->size() == 0) {
return 0;
}
std::vector<int> online_cpu_ids;
char path[256];
for (int i = 0; i < cpu_ids->size(); i++) {
int cpu_id = (*cpu_ids)[i];
snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online",
cpu_id);
FILE *fp = fopen(path, "rb");
if (fp) {
int is_online = 0;
fscanf(fp, "%d", &is_online);
fclose(fp);
if (is_online != 0) {
online_cpu_ids.push_back(cpu_id);
}
}
// if open failed (e.g. Permission denied), the cpu is treated as offline
}
*cpu_ids = online_cpu_ids;
return cpu_ids->size();
}
int set_sched_affinity(const std::vector<int> &cpu_ids) {
// cpu_set_t definition
// ref http://stackoverflow.com/questions/16319725/android-set-thread-affinity
#define CPU_SETSIZE 1024
#define __NCPUBITS (8 * sizeof(unsigned long))
typedef struct {
unsigned long __bits[CPU_SETSIZE / __NCPUBITS];
} cpu_set_t;
#define CPU_SET(cpu, cpusetp) \
((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t))
// set affinity for thread
#ifdef __GLIBC__
pid_t pid = syscall(SYS_gettid);
#else
pid_t pid = gettid();
#endif
cpu_set_t mask;
CPU_ZERO(&mask);
for (int i = 0; i < cpu_ids.size(); i++) {
CPU_SET(cpu_ids[i], &mask);
}
int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask);
if (syscallret) {
LOG(kLOG_WARNING) << "invoke syscall(__NR_sched_setaffinity) error(ret="
<< syscallret << ")";
return -1;
}
return 0;
}
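// A minimal sketch of how this helper is used (hypothetical core ids; on
// many big.LITTLE SoCs the big cluster is cpu4-cpu7):
//
//   std::vector<int> big_cores = {4, 5, 6, 7};
//   if (set_sched_affinity(big_cores) != 0) {
//     // binding failed, leave scheduling to the system
//   }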
int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
std::vector<int> *little_core_ids,
std::vector<int> *l1_cache_sizes,
std::vector<int> *l2_cache_sizes,
std::vector<int> *l3_cache_sizes,
std::string hardware_name) {
/* Snapdragon */
if (hardware_name.find("SDM845") != std::string::npos) { // 845
*cpu_num = 8;
*big_core_ids = {4, 5, 6, 7};
*little_core_ids = {0, 1, 2, 3};
l1_cache_sizes->resize(*cpu_num);
l2_cache_sizes->resize(*cpu_num);
l3_cache_sizes->resize(*cpu_num);
fill_cpu_cache_size(l1_cache_sizes, 64 * 1024);
fill_cpu_cache_size(l2_cache_sizes, 256 * 1024, *big_core_ids);
fill_cpu_cache_size(l2_cache_sizes, 128 * 1024, *little_core_ids);
fill_cpu_cache_size(l3_cache_sizes, 2048 * 1024);
return 0;
} else if (hardware_name.find("SDM710") != std::string::npos) { // 710
*cpu_num = 8;
*big_core_ids = {6, 7};
*little_core_ids = {0, 1, 2, 3, 4, 5};
l1_cache_sizes->resize(*cpu_num);
l2_cache_sizes->resize(*cpu_num);
l3_cache_sizes->resize(*cpu_num);
fill_cpu_cache_size(l1_cache_sizes, 64 * 1024, *big_core_ids);
fill_cpu_cache_size(l1_cache_sizes, 32 * 1024, *little_core_ids);
fill_cpu_cache_size(l2_cache_sizes, 256 * 1024, *big_core_ids);
fill_cpu_cache_size(l2_cache_sizes, 128 * 1024, *little_core_ids);
fill_cpu_cache_size(l3_cache_sizes, 1024 * 1024);
return 0;
} else if (hardware_name.find("MSM8998") != std::string::npos) { // 835
*cpu_num = 8;
*big_core_ids = {4, 5, 6, 7};
*little_core_ids = {0, 1, 2, 3};
l1_cache_sizes->resize(*cpu_num);
l2_cache_sizes->resize(*cpu_num);
l3_cache_sizes->resize(*cpu_num);
fill_cpu_cache_size(l1_cache_sizes, 64 * 1024, *big_core_ids);
fill_cpu_cache_size(l1_cache_sizes, 32 * 1024, *little_core_ids);
// the real L2 cache size is 2MB, but that gives bad performance for
// conv3x3s1 and gemm, so set it to 1MB or 512KB
// fill_cpu_cache_size(l2_cache_sizes, 2048 *1024,
// *big_core_ids);
// fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024,
// *little_core_ids);
fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024);
fill_cpu_cache_size(l3_cache_sizes, 0);
return 0;
} else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653
*cpu_num = 8;
*big_core_ids = {0, 1, 2, 3, 4, 5, 6, 7};
*little_core_ids = {};
l1_cache_sizes->resize(*cpu_num);
l2_cache_sizes->resize(*cpu_num);
l3_cache_sizes->resize(*cpu_num);
fill_cpu_cache_size(l1_cache_sizes, 32 * 1024);
fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024);
fill_cpu_cache_size(l3_cache_sizes, 0);
return 0;
}
return -1;
}
// divide cpu cores into big and little clusters by max frequency
void get_cpu_info_by_probe(int cpu_num, std::vector<int> *big_core_ids,
std::vector<int> *little_core_ids,
std::vector<int> *l1_cache_sizes,
std::vector<int> *l2_cache_sizes,
std::vector<int> *l3_cache_sizes) {
// get maximum & minimum of cpu_max_freqs
std::vector<int> cpu_max_freqs(cpu_num);
for (int i = 0; i < cpu_num; i++) {
cpu_max_freqs[i] = get_cpu_max_freq_khz(i) / 1000;  // kHz -> MHz
}
int max_cpu_max_freq = cpu_max_freqs[0];
int min_cpu_max_freq = cpu_max_freqs[0];
for (int i = 1; i < cpu_num; i++) {
int cur_cpu_max_freq = cpu_max_freqs[i];
if (cur_cpu_max_freq < min_cpu_max_freq) {
min_cpu_max_freq = cur_cpu_max_freq;
} else if (cur_cpu_max_freq > max_cpu_max_freq) {
max_cpu_max_freq = cur_cpu_max_freq;
}
}
int mid_max_freq = (max_cpu_max_freq + min_cpu_max_freq) / 2;  // in MHz
big_core_ids->clear();
little_core_ids->clear();
for (int i = 0; i < cpu_num; i++) {
if (cpu_max_freqs[i] >= mid_max_freq) {
big_core_ids->push_back(i);
} else {
little_core_ids->push_back(i);
}
}
/* get l1, l2, l3 cache size for each core */
l1_cache_sizes->resize(cpu_num);
l2_cache_sizes->resize(cpu_num);
l3_cache_sizes->resize(cpu_num);
for (int i = 0; i < cpu_num; i++) {
get_cpu_cache_size(i, &((*l1_cache_sizes)[i]), &((*l2_cache_sizes)[i]),
&((*l3_cache_sizes)[i]));
}
}
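// A worked example of the split above (hypothetical frequencies): with max
// freqs {1804, 1804, 1804, 1804, 2419, 2419, 2419, 2419} MHz, the midpoint
// is (2419 + 1804) / 2 = 2111, so cores 4-7 (>= 2111) form the big cluster
// and cores 0-3 the little cluster.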
void bind_threads(const std::vector<int> &cpu_ids) {
#ifdef _OPENMP
int num_threads = omp_get_max_threads();
std::vector<int> ssarets;
for (int i = 0; i < num_threads; i++) {
ssarets.push_back(0);
}
#pragma omp parallel for
for (int i = 0; i < num_threads; i++) {
ssarets[i] = set_sched_affinity(cpu_ids);
}
for (int i = 0; i < num_threads; i++) {
if (ssarets[i] != 0) {
LOG(kLOG_WARNING) << "set cpu affinity failed, thread idx: " << i;
return;
}
}
#else
int ssaret = set_sched_affinity(cpu_ids);
if (ssaret != 0) {
LOG(kLOG_WARNING) << "set cpu affinity failed, thread idx: 0 ";
return;
}
#endif
}
#endif
CPUContext::CPUContext() {
_cpu_num = get_cpu_num();
_big_core_ids.clear();
_little_core_ids.clear();
#ifdef __APPLE__
// set default L1, L2 and L3 cache sizes
_l1_cache_sizes.resize(_cpu_num);
_l2_cache_sizes.resize(_cpu_num);
_l3_cache_sizes.resize(_cpu_num);
fill_cpu_cache_size(&_l1_cache_sizes, DEFAULT_L1_CACHE_SIZE);
fill_cpu_cache_size(&_l2_cache_sizes, DEFAULT_L2_CACHE_SIZE);
fill_cpu_cache_size(&_l3_cache_sizes, DEFAULT_L3_CACHE_SIZE);
#else // Linux or Android
// probe cpu info, and set the big & little clusters and the L1, L2 and L3 cache sizes
std::string cpu_name = get_cpu_name();
bool failed =
get_cpu_info_by_name(&_cpu_num, &_big_core_ids, &_little_core_ids,
&_l1_cache_sizes, &_l2_cache_sizes, &_l3_cache_sizes,
cpu_name) != 0;
if (failed) {
get_cpu_info_by_probe(_cpu_num, &_big_core_ids, &_little_core_ids,
&_l1_cache_sizes, &_l2_cache_sizes, &_l3_cache_sizes);
}
LOG(kLOG_INFO) << "CPU num: " << _cpu_num;
for (int i = 0; i < _cpu_num; i++) {
LOG(kLOG_INFO) << i << " L1 Cache: " << _l1_cache_sizes[i] / 1024 << "KB"
<< " L2 Cache: " << _l2_cache_sizes[i] / 1024 << "KB"
<< " L3 Cache: " << _l3_cache_sizes[i] / 1024 << "KB";
}
LOG(kLOG_INFO) << "Big cores: ";
for (int i = 0; i < _big_core_ids.size(); i++) {
LOG(kLOG_INFO) << _big_core_ids[i];
}
LOG(kLOG_INFO) << "Little cores: ";
for (int i = 0; i < _little_core_ids.size(); i++) {
LOG(kLOG_INFO) << _little_core_ids[i];
}
#endif
// use single thread by default
set_thread_num(1, PERFORMANCE_PRIORITY);
}
void CPUContext::set_thread_num(int thread_num, PowerMode power_mode) {
int big_core_num = _big_core_ids.size();
int little_core_num = _little_core_ids.size();
#ifdef _OPENMP
if (thread_num > _cpu_num) {
thread_num = _cpu_num;
}
#else
thread_num = 1;
#endif
std::vector<int> bind_core_ids;
if (power_mode == PERFORMANCE_PRIORITY || power_mode == PERFORMANCE_ONLY) {
if (big_core_num > 0) {
bind_core_ids = _big_core_ids;
if (power_mode == PERFORMANCE_ONLY && thread_num > big_core_num) {
LOG(kLOG_ERROR) << "thread_num(" << thread_num
<< ") exceeds the number of big cores (" << big_core_num << ")"
<< ", force to set thread_num = " << big_core_num;
thread_num = big_core_num;
}
}
} else if (power_mode == EFFICIENCY_PRIORITY ||
power_mode == EFFICIENCY_ONLY) {
if (little_core_num > 0) {
bind_core_ids = _little_core_ids;
if (power_mode == EFFICIENCY_ONLY && thread_num > little_core_num) {
LOG(kLOG_ERROR) << "thread_num(" << thread_num
<< ") exceeds the number of little cores (" << little_core_num
<< ")"
<< ", force to set thread_num = " << little_core_num;
thread_num = little_core_num;
}
}
}
_power_mode = AUTO;
#ifdef _OPENMP
omp_set_num_threads(thread_num);
thread_num = omp_get_max_threads();
#endif
#if !defined(__APPLE__) // Linux or Android
if (bind_core_ids.size() > 0 && check_online(&bind_core_ids) >= thread_num) {
bind_threads(bind_core_ids);
_power_mode = power_mode;
}
#endif
LOG(kLOG_INFO) << "thread num: " << thread_num
<< " power mode: " << _power_mode;
}
int CPUContext::get_thread_num() {
int thread_num = 1;
#ifdef _OPENMP
thread_num = omp_get_max_threads();
#endif
return thread_num;
}
int CPUContext::get_cache_size(int level) {
std::vector<int> *ptr = nullptr;
if (level == 1) {
ptr = &_l1_cache_sizes;
} else if (level == 2) {
ptr = &_l2_cache_sizes;
} else if (level == 3) {
ptr = &_l3_cache_sizes;
} else {
return 0;
}
if (_power_mode == PERFORMANCE_PRIORITY || _power_mode == PERFORMANCE_ONLY) {
return (*ptr)[_big_core_ids[0]];
} else if (_power_mode == EFFICIENCY_PRIORITY ||
_power_mode == EFFICIENCY_ONLY) {
return (*ptr)[_little_core_ids[0]];
} else { // AUTO
return (*ptr)[0];
}
}
void *CPUContext::get_work_space(int size_in_byte) {
return reinterpret_cast<void *>(
_workspace.mutable_data<int8_t>(make_ddim({size_in_byte})));
}
} // namespace framework
} // namespace paddle_mobile
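// A minimal usage sketch of the CPUContext API defined above (Context()
// returns a lazily created singleton; cache sizes are in bytes):
//
//   using paddle_mobile::framework::CPUContext;
//   CPUContext *ctx = CPUContext::Context();
//   ctx->set_thread_num(2, paddle_mobile::EFFICIENCY_PRIORITY);
//   int l2 = ctx->get_l2_cache_size();
//   float *scratch = static_cast<float *>(ctx->get_work_space(l2));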
......@@ -18,63 +18,45 @@ limitations under the License. */
#include <omp.h>
#endif
#define MOBILE_MAX_CPU_NUM 8
#include <vector>
#include "framework/tensor.h"
namespace paddle_mobile {
namespace framework {
struct CPUContext {
private:
CPUContext() : num_cpus(4), num_threads(1) {
// TODO(hjchen2)
for (int i = 0; i < num_cpus; ++i) {
cpu_frequencies[i] = 2400; // 2400 MHz
max_cpu_frequencies[i] = 2400; // 2400 MHz
}
// L1_cache = 32000; // 32K
L1_cache = 32 * 1024;
L2_cache = 2000000; // 2M
// L2_cache = 512000;
}
public:
void set_num_threads(int threads) {
#ifdef _OPENMP
omp_set_num_threads(threads);
if (threads <= omp_get_max_threads()) {
num_threads = threads;
} else {
num_threads = omp_get_max_threads();
}
#endif
num_threads = (num_threads > 1) ? num_threads : 1;
}
CPUContext();
virtual ~CPUContext() {}
public:
static CPUContext* Context() {
static CPUContext* ctx = new CPUContext;
static CPUContext* ctx = nullptr;
if (ctx == nullptr) {
ctx = new CPUContext();
}
return ctx;
}
int num_cpus;
int num_threads;
int cpu_frequencies[MOBILE_MAX_CPU_NUM];
int max_cpu_frequencies[MOBILE_MAX_CPU_NUM];
int L1_cache;
int L2_cache;
void set_thread_num(int thread_num,
PowerMode power_mode = PERFORMANCE_PRIORITY);
int get_thread_num();
PowerMode get_power_mode() const { return _power_mode; }
int get_cache_size(int level);
int get_l1_cache_size() { return get_cache_size(1); }
int get_l2_cache_size() { return get_cache_size(2); }
int get_l3_cache_size() { return get_cache_size(3); }
void* get_work_space(int size_in_byte);
int _cpu_num;
PowerMode _power_mode;
std::vector<int> _big_core_ids;
std::vector<int> _little_core_ids;
std::vector<int> _l1_cache_sizes;
std::vector<int> _l2_cache_sizes;
std::vector<int> _l3_cache_sizes;
Tensor _workspace;
};
inline void set_global_num_threads(int threads) {
// CPUContext::Context()->set_num_threads(threads);
CPUContext::Context()->num_threads = threads;
}
inline int get_global_num_threads() {
return CPUContext::Context()->num_threads;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -40,8 +40,8 @@ namespace framework {
#pragma mark - executor
template <typename Device, typename T>
void Executor<Device, T>::SetThreadNum(int threads) {
set_global_num_threads(threads);
void Executor<Device, T>::SetThreadNum(int thread_num, PowerMode power_mode) {
CPUContext::Context()->set_thread_num(thread_num, power_mode);
}
template <typename Device, typename T>
......@@ -440,7 +440,7 @@ std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
template <typename Device, typename T>
PMStatus Executor<Device, T>::Predict() {
#if _OPENMP
omp_set_num_threads(get_global_num_threads());
omp_set_num_threads(CPUContext::Context()->get_thread_num());
#endif
// clear all no persistable tensor array since write_to_array
// is always push back a new tensor in the array
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "common/types.h"
......@@ -37,7 +38,8 @@ class Executor {
paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1,
const bool use_optimize = true, const bool lod_mode = false);
void SetThreadNum(int threads);
void SetThreadNum(int thread_num,
PowerMode power_mode = PERFORMANCE_PRIORITY);
PMStatus Predict(const std::vector<std::pair<std::string, Tensor>> &inputs);
PMStatus Predict(
......
......@@ -29,8 +29,9 @@ limitations under the License. */
namespace paddle_mobile {
template <typename Device, typename T>
void PaddleMobile<Device, T>::SetThreadNum(int num) {
executor_->SetThreadNum(num);
void PaddleMobile<Device, T>::SetThreadNum(int thread_num,
PowerMode power_mode) {
executor_->SetThreadNum(thread_num, power_mode);
}
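// Hypothetical call sites showing the new default argument:
//
//   paddle_mobile::PaddleMobile<paddle_mobile::CPU, float> engine;
//   engine.SetThreadNum(4);                                  // defaults to PERFORMANCE_PRIORITY
//   engine.SetThreadNum(4, paddle_mobile::EFFICIENCY_ONLY);  // pin to little cores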
template <typename Device, typename T>
......
......@@ -83,7 +83,8 @@ class PaddleMobile {
bool quantification = false, int batch_size = 1,
bool lod_mode = false);
void SetThreadNum(int count);
void SetThreadNum(int thread_num,
PowerMode power_mode = PERFORMANCE_PRIORITY);
void Clear();
double GetPredictTime();
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/math/slidingwindow_utils.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
......@@ -56,38 +57,31 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
} else if (conv3x3 && param->Groups() == 1 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 1
&& (param->Input()->dims()[1] >= 8 &&
param->Output()->dims()[1] >= 8)
#endif
) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
param->Strides()[0] == 1 && param->Dilations()[0] == 1) {
// transform weight
Variable *transformed_var = param->GetScope()->Var();
param->transformed_filter_ =
transformed_var->GetMutable<framework::LoDTensor>();
operators::math::winograd_transform_weight<8, 3>(
*param->Filter(), param->transformed_filter_);
} else if (conv3x3 && param->Groups() == 1 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 1
&& (param->Input()->dims()[2] >= 48 &&
param->Output()->dims()[1] <= 24)
#endif
) {
if (param->Input()->dims()[1] >= 32 && param->Output()->dims()[1] >= 32 &&
param->Output()->dims()[2] > 16 && param->Output()->dims()[3] > 16) {
math::winograd_transform_weight<8, 3>(*param->Filter(),
param->transformed_filter_);
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
} else {
math::slidingwindow_transform_weight<float>(*param->Filter(),
param->transformed_filter_);
param->ExecMode() = ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT;
}
} else if (conv3x3 && param->Groups() == 1 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 2 && param->Dilations()[0] == 1
#if 1
&& (param->Input()->dims()[2] >= 48 &&
param->Output()->dims()[1] <= 24)
#endif
) {
param->Strides()[0] == 2 && param->Dilations()[0] == 1) {
// transform weight
Variable *transformed_var = param->GetScope()->Var();
param->transformed_filter_ =
transformed_var->GetMutable<framework::LoDTensor>();
math::slidingwindow_transform_weight<float>(*param->Filter(),
param->transformed_filter_);
param->ExecMode() = ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
......
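// Recap of the 3x3 / group==1 / dilation==1 dispatch above, as sketched
// from the diff (thresholds are the ones in the code): for stride 1, pick
// Winograd when chin >= 32, chout >= 32 and the output is larger than
// 16x16, otherwise SlidingwindowConv3x3s1; for stride 2, always
// SlidingwindowConv3x3s2; every other configuration falls back to GEMM.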
......@@ -243,9 +243,15 @@ void SlidingwindowConv3x3(const ConvParam<CPU> &param) {
output->mutable_data<Otype>();
if (strides[0] == 1) {
math::SlidingwindowConv3x3s1<Itype, Otype>(input, filter, paddings, output);
// math::SlidingwindowConv3x3s1<Itype, Otype>(input, filter, paddings,
// output);
math::SlidingwindowConv3x3s1Faster<Itype, Otype>(
input, param.transformed_filter_, paddings, output);
} else if (strides[0] == 2) {
math::SlidingwindowConv3x3s2<Itype, Otype>(input, filter, paddings, output);
// math::SlidingwindowConv3x3s2<Itype, Otype>(input, filter, paddings,
// output);
math::SlidingwindowConv3x3s2Faster<Itype, Otype>(
input, param.transformed_filter_, paddings, output);
} else {
GemmConv<Itype, Otype>(param);
}
......
......@@ -29,8 +29,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {
static framework::CPUContext *g_cpu_ctx = framework::CPUContext::Context();
int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
const int N, const int K) {
......@@ -70,11 +68,15 @@ class GemmExecutor : public Executor {
unsigned int L1_size = 0;
unsigned int L2_size = 0;
if (M_ > N_) {
L2_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, M_, K_);
L1_size = g_cpu_ctx->L2_cache;
L2_size =
ResetL1Cache(framework::CPUContext::Context()->get_l1_cache_size(),
num_threads_, M_, K_);
L1_size = framework::CPUContext::Context()->get_l2_cache_size();
} else {
L1_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, N_, K_);
L2_size = g_cpu_ctx->L2_cache;
L1_size =
ResetL1Cache(framework::CPUContext::Context()->get_l1_cache_size(),
num_threads_, N_, K_);
L2_size = framework::CPUContext::Context()->get_l2_cache_size();
}
rhs_tile_num_ = L1_size / (K_ * sizeof(Itype));
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "operators/math/slidingwindow_conv3x3.h"
#include <vector>
#include "framework/context.h"
#include "operators/math/slidingwindow_utils.h"
#if __ARM_NEON
#include <arm_neon.h>
#endif
......@@ -703,7 +705,7 @@ void SlidingwindowConv3x3s1<float, float>(const framework::Tensor *input,
in_ptr3--;
in_ptr4--;
}
#endif //__aarch64__
#endif // __aarch64__
#endif // __ARM_NEON
// remain output_width
......@@ -1250,7 +1252,7 @@ void SlidingwindowConv3x3s1<float, float>(const framework::Tensor *input,
}
}
#endif //__aarch64__
#endif // __aarch64__
#endif // __ARM_NEON
// remain output_width
......@@ -1738,7 +1740,7 @@ void SlidingwindowConv3x3s1<float, float>(const framework::Tensor *input,
in_ptr3--;
in_ptr4--;
}
#endif //__aarch64__
#endif // __aarch64__
#endif // __ARM_NEON
// remain output_width
......@@ -2940,7 +2942,7 @@ void SlidingwindowConv3x3s2<float, float>(const framework::Tensor *input,
in_ptr3 = in_ptr2 + input_w;
}
}
#endif //__aarch64__
#endif // __aarch64__
#endif // __ARM_NEON
// remain output_width
......@@ -3594,7 +3596,7 @@ void SlidingwindowConv3x3s2<float, float>(const framework::Tensor *input,
"q7", "q8", "q10", "q12", "q13", "q14", "q15");
}
}
#endif //__aarch64__
#endif // __aarch64__
#endif // __ARM_NEON
out_ptr1 -= 4;
out_ptr1 += 4;
......@@ -3705,6 +3707,1956 @@ void SlidingwindowConv3x3s2<float, float>(const framework::Tensor *input,
}
}
template <>
void SlidingwindowConv3x3s1Faster<float, float>(
const framework::Tensor *input, framework::Tensor *filter,
const std::vector<int> &paddings, framework::Tensor *output) {
const float *din = input->data<float>();
float *dout = output->mutable_data<float>();
const float *weights = filter->mutable_data<float>();
const float *bias = nullptr;
bool relu = false;
const int num = input->dims()[0];
const int chin = input->dims()[1];
const int hin = input->dims()[2];
const int win = input->dims()[3];
const int chout = output->dims()[1];
const int hout = output->dims()[2];
const int wout = output->dims()[3];
const int pad_h = paddings[0];
const int pad_w = paddings[1];
const int threads = framework::CPUContext::Context()->get_thread_num();
int l2_size =
framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float);
const int hout_c_block = 4;
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round + 2;
int hout_r_block = (l2_size - 2 * win_round * chin) /
(win_round * chin + hout_c_block * wout_round * threads);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
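// A worked example of the block-size arithmetic above (hypothetical
// shapes): with a 512KB L2 (l2_size = 131072 floats), chin = 32,
// wout = 112 (so wout_round = 112, win_round = 114) and threads = 4:
//   hout_r_block = (131072 - 2 * 114 * 32) /
//                  (114 * 32 + 4 * 112 * 4) = 123776 / 5440 = 22
// which is already a multiple of hout_r_kernel (2), so 22 output rows
// are produced per outer iteration.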
const int hin_r_block = hout_r_block + 2;
float ptr_zero[win_round];
memset(ptr_zero, 0, sizeof(float) * win_round);
float ptr_write[wout_round];
int in_len = win_round * chin;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = hout_c_block * hout_r_block * wout_round;
float *pre_din =
static_cast<float *>(framework::CPUContext::Context()->get_work_space(
(pre_in_size + threads * pre_out_size) * sizeof(float)));
int size_in_channel = win * hin;
int size_out_channel = wout * hout;
int w_stride = chin * 9;               // kernel_h * kernel_w
int w_stride_chin = hout_c_block * 9;  // kernel_h * kernel_w * hout_c_block
int ws = -pad_w;
int we = ws + win_round;
int w_loop = wout_round / 4;
int c_remain = chout - (chout / hout_c_block) * hout_c_block;
int c_round_down = (chout / hout_c_block) * hout_c_block;
int out_row_stride = hout_c_block * wout_round;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * chin * size_in_channel;
float *dout_batch = dout + n * chout * size_out_channel;
for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
int hs = h - pad_h;
int he = hs + h_kernel + 2;
slidingwindow_prepack_input(din_batch, pre_din, 0, chin, hs, he, ws, we,
chin, win, hin, ptr_zero);
#pragma omp parallel for
for (int c = 0; c < chout - (hout_c_block - 1); c += hout_c_block) {
#ifdef _OPENMP
float *pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
#else
float *pre_out = pre_din + pre_in_size;
#endif
const float *block_inr0 = pre_din;
const float *block_inr1 = block_inr0 + in_len;
const float *block_inr2 = block_inr1 + in_len;
const float *block_inr3 = block_inr2 + in_len;
const float *weight_c = weights + c * w_stride;
const float *bias_ptr = ptr_zero;
if (bias != nullptr) {
bias_ptr = bias + c;
}
slidingwindow_fill_bias(pre_out, bias_ptr,
wout_round * hout_c_block * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const float *wc0 = weight_c;
const float *inr0 = block_inr0;
const float *inr1 = block_inr1;
const float *inr2 = block_inr2;
const float *inr3 = block_inr3;
float *pre_out0 = pre_out + hk * out_row_stride;
float *pre_out1 = pre_out0 + out_row_stride;
#ifdef __aarch64__
for (int i = 0; i < chin; ++i) {
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
float32x4_t w0 = vld1q_f32(wc0); // w0, v23
float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24
float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25
float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26
float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27
float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28
float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29
float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30
float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
int cnt = w_loop;
asm volatile(
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/
"2: \n" /* main loop*/
/* r0, r1, mul w0, get out r0, r1 */
"fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/
"fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/
"fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/
"fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/
"fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/
"fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/
"fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/
"fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/
/* r0, r1, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/
"fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/
"fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/
"fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/
"fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/
"fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/
"fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/
"fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
/* r0, r1, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/
"fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/
"fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/
"fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/
"fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/
"fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/
"fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/
/* r1, r2, mul w3, get out r0, r1 */
"fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/
"fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/
"fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/
"fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/
"fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/
"fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/
"fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/
"fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/
"ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/
/* r1, r2, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/
"fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/
"fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/
"fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/
"fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/
"fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/
"fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/
"fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
/* r1, r2, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/
"fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/
"fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/
"fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/
"fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/
"fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/
"fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/
"fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/
/* r2, r3, mul w6, get out r0, r1 */
"fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/
"fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/
"fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/
"fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/
"fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/
"fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/
"fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/
"ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/
/* r2, r3, mul w7, get out r0, r1 */
"fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/
"fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/
"fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/
"fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/
"fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/
"fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/
"fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
/* r2, r3, mul w8, get out r0, r1 */
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/
"fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/
"fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/
"fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/
"stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/
"fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/
"stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/
"fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
"fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/
"stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/
"stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3),
[w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7),
[w8] "w"(w8)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22");
wc0 += 9 * hout_c_block;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
}
#else // not __aarch64__
for (int i = 0; i < chin; ++i) {
const float *wc0 = weight_c + i * w_stride_chin;
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
int cnt = w_loop;
asm volatile(
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
/* load weights */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
/* load r0, r1 */
"vld1.32 {d0-d1}, [%[r0]]! @ load r0, "
"4 float\n"
"vld1.32 {d2}, [%[r0]] @ load r0, "
"2 float\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
/* main loop */
"0: @ main "
"loop\n"
/* mul r0 with w0, w1, w2, get out r0 */
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load "
"outr1, w0, w1, c0~c3\n"
"vmla.f32 q8, q5, d0[0] @ w0 * "
"inr00\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load "
"outr1, w2, w3, c0~c3\n"
"vmla.f32 q9, q5, d0[1] @ w0 * "
"inr01\n"
"vmla.f32 q10, q5, d1[0] @ w0 * "
"inr02\n"
"vmla.f32 q11, q5, d1[1] @ w0 * "
"inr03\n"
"vld1.32 {d3-d4}, [%[r1]]! @ load r1, "
"4 float\n"
"vmla.f32 q8, q6, d0[1] @ w1 * "
"inr01\n"
"vmla.f32 q9, q6, d1[0] @ w1 * "
"inr02\n"
"vmla.f32 q10, q6, d1[1] @ w1 * "
"inr03\n"
"vmla.f32 q11, q6, d2[0] @ w1 * "
"inr04\n"
"vld1.32 {d5}, [%[r1]] @ load r0, "
"2 float\n"
"vmla.f32 q8, q7, d1[0] @ w2 * "
"inr02\n"
"vmla.f32 q9, q7, d1[1] @ w2 * "
"inr03\n"
"vmla.f32 q10, q7, d2[0] @ w2 * "
"inr04\n"
"vmla.f32 q11, q7, d2[1] @ w2 * "
"inr05\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 "
"- 32, to start address\n"
/* mul r1 with w0, w1, w2, get out r1 */
"vmla.f32 q12, q5, d3[0] @ w0 * "
"inr10\n"
"vmla.f32 q13, q5, d3[1] @ w0 * "
"inr11\n"
"vmla.f32 q14, q5, d4[0] @ w0 * "
"inr12\n"
"vmla.f32 q15, q5, d4[1] @ w0 * "
"inr13\n"
"vmla.f32 q12, q6, d3[1] @ w1 * "
"inr11\n"
"vmla.f32 q13, q6, d4[0] @ w1 * "
"inr12\n"
"vmla.f32 q14, q6, d4[1] @ w1 * "
"inr13\n"
"vmla.f32 q15, q6, d5[0] @ w1 * "
"inr14\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, "
"w4, to q5, q6\n"
"vmla.f32 q12, q7, d4[0] @ w2 * "
"inr12\n"
"vmla.f32 q13, q7, d4[1] @ w2 * "
"inr13\n"
"vmla.f32 q14, q7, d5[0] @ w2 * "
"inr14\n"
"vmla.f32 q15, q7, d5[1] @ w2 * "
"inr15\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, "
"to q7\n"
/* mul r1 with w3, w4, w5, get out r0 */
"vmla.f32 q8, q5, d3[0] @ w3 * "
"inr10\n"
"vmla.f32 q9, q5, d3[1] @ w3 * "
"inr11\n"
"vmla.f32 q10, q5, d4[0] @ w3 * "
"inr12\n"
"vmla.f32 q11, q5, d4[1] @ w3 * "
"inr13\n"
"vld1.32 {d0-d1}, [%[r2]]! @ load r2, "
"4 float\n"
"vmla.f32 q8, q6, d3[1] @ w4 * "
"inr11\n"
"vmla.f32 q9, q6, d4[0] @ w4 * "
"inr12\n"
"vmla.f32 q10, q6, d4[1] @ w4 * "
"inr13\n"
"vmla.f32 q11, q6, d5[0] @ w4 * "
"inr14\n"
"vld1.32 {d2}, [%[r2]] @ load r2, "
"2 float\n"
"vmla.f32 q8, q7, d4[0] @ w5 * "
"inr12\n"
"vmla.f32 q9, q7, d4[1] @ w5 * "
"inr13\n"
"vmla.f32 q10, q7, d5[0] @ w5 * "
"inr14\n"
"vmla.f32 q11, q7, d5[1] @ w5 * "
"inr15\n"
/* mul r2 with w3, w4, w5, get out r1 */
"vmla.f32 q12, q5, d0[0] @ w3 * "
"inr20\n"
"vmla.f32 q13, q5, d0[1] @ w3 * "
"inr21\n"
"vmla.f32 q14, q5, d1[0] @ w3 * "
"inr22\n"
"vmla.f32 q15, q5, d1[1] @ w3 * "
"inr23\n"
"vmla.f32 q12, q6, d0[1] @ w4 * "
"inr21\n"
"vmla.f32 q13, q6, d1[0] @ w4 * "
"inr22\n"
"vmla.f32 q14, q6, d1[1] @ w4 * "
"inr23\n"
"vmla.f32 q15, q6, d2[0] @ w4 * "
"inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
"w7, to q5, q6\n"
"vmla.f32 q12, q7, d1[0] @ w5 * "
"inr22\n"
"vmla.f32 q13, q7, d1[1] @ w5 * "
"inr23\n"
"vmla.f32 q14, q7, d2[0] @ w5 * "
"inr24\n"
"vmla.f32 q15, q7, d2[1] @ w5 * "
"inr25\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
"to q7\n"
"sub %[wc0], %[wc0], #144 @ wc0 - "
"144 to start address\n"
/* mul r2 with w6, w7, w8, get out r0 */
"vmla.f32 q8, q5, d0[0] @ w6 * "
"inr20\n"
"vmla.f32 q9, q5, d0[1] @ w6 * "
"inr21\n"
"vld1.32 {d3-d4}, [%[r3]]! @ load r3, "
"4 float\n"
"vmla.f32 q10, q5, d1[0] @ w6 * "
"inr22\n"
"vmla.f32 q11, q5, d1[1] @ w6 * "
"inr23\n"
"vmla.f32 q8, q6, d0[1] @ w7 * "
"inr21\n"
"vmla.f32 q9, q6, d1[0] @ w7 * "
"inr22\n"
"vld1.32 {d5}, [%[r3]] @ load r3, "
"2 float\n"
"vmla.f32 q10, q6, d1[1] @ w7 * "
"inr23\n"
"vmla.f32 q11, q6, d2[0] @ w7 * "
"inr24\n"
"vmla.f32 q8, q7, d1[0] @ w8 * "
"inr22\n"
"vmla.f32 q9, q7, d1[1] @ w8 * "
"inr23\n"
"vld1.32 {d0-d1}, [%[r0]]! @ load r0, "
"4 float\n"
"vmla.f32 q10, q7, d2[0] @ w8 * "
"inr24\n"
"vmla.f32 q11, q7, d2[1] @ w8 * "
"inr25\n"
"vld1.32 {d2}, [%[r0]] @ load r0, "
"2 float\n"
/* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q12, q5, d3[0] @ w6 * "
"inr20\n"
"vmla.f32 q13, q5, d3[1] @ w6 * "
"inr21\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save "
"r00, r01, c0~c3\n"
"vmla.f32 q14, q5, d4[0] @ w6 * "
"inr22\n"
"vmla.f32 q15, q5, d4[1] @ w6 * "
"inr23\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save "
"r02, r03, c0~c3\n"
"vmla.f32 q12, q6, d3[1] @ w7 * "
"inr21\n"
"vmla.f32 q13, q6, d4[0] @ w7 * "
"inr22\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vmla.f32 q14, q6, d4[1] @ w7 * "
"inr23\n"
"vmla.f32 q15, q6, d5[0] @ w7 * "
"inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vmla.f32 q12, q7, d4[0] @ w8 * "
"inr22\n"
"vmla.f32 q13, q7, d4[1] @ w8 * "
"inr23\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"vmla.f32 q14, q7, d5[0] @ w8 * "
"inr24\n"
"vmla.f32 q15, q7, d5[1] @ w8 * "
"inr25\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
"r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
"r12, r13, c0~c3\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"subs %[cnt], #1 @ loop "
"count--\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1), [wc0] "+r"(wc0)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
}
#endif // __aarch64__
block_inr0 = block_inr2;
block_inr1 = block_inr3;
block_inr2 = block_inr1 + in_len;
block_inr3 = block_inr2 + in_len;
}
slidingwindow_writeout_c4_fp32(pre_out, dout_batch, c, c + hout_c_block,
h, h + h_kernel, 0, wout_round, chout,
hout, wout, relu, ptr_write);
}
const float *weight_remain_ptr = weights + c_round_down * w_stride;
#pragma omp parallel for
for (int c = 0; c < c_remain; ++c) {
#ifdef _OPENMP
float *pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
#else
float *pre_out = pre_din + pre_in_size;
#endif
int c_idx = c_round_down + c;
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
const float *block_inr0 = pre_din;
const float *block_inr1 = block_inr0 + in_len;
const float *block_inr2 = block_inr1 + in_len;
const float *block_inr3 = block_inr2 + in_len;
const float *bias_ptr = ptr_zero;
if (bias != nullptr) {
bias_ptr = bias + c_idx;
}
slidingwindow_fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const float *wc0 = weight_remain_ptr;
const float *inr0 = block_inr0;
const float *inr1 = block_inr1;
const float *inr2 = block_inr2;
const float *inr3 = block_inr3;
float *pre_out0 = pre_out + hk * wout_round;
float *pre_out1 = pre_out0 + wout_round;
#ifdef __aarch64__
for (int i = 0; i < chin; ++i) {
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23
float32x4_t w1 = vdupq_n_f32(wc0[4 + c]); // w1, v24
float32x4_t w2 = vdupq_n_f32(wc0[8 + c]); // w2, v25
float32x4_t w3 = vdupq_n_f32(wc0[12 + c]); // w3, v26
float32x4_t w4 = vdupq_n_f32(wc0[16 + c]); // w4, v27
float32x4_t w5 = vdupq_n_f32(wc0[20 + c]); // w5, v28
float32x4_t w6 = vdupq_n_f32(wc0[24 + c]); // w6, v29
float32x4_t w7 = vdupq_n_f32(wc0[28 + c]); // w7, v30
float32x4_t w8 = vdupq_n_f32(wc0[32 + c]); // w8, v31
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
int cnt = w_loop;
asm volatile(
"ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/
"ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"2: \n" /* main loop*/
"fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/
"fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/
"ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/
"ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/
"ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/
"ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/
"fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/
"fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/
"fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/
"fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/
"ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/
"ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/
"ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/
"ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/
"fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/
"fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/
"fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/
"fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/
"fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/
"fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/
"fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/
"str q21, [%[ptr_out0]], #16 \n" /*write output r0*/
"str q22, [%[ptr_out1]], #16 \n" /*write output r1*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/
"ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3),
[w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7),
[w8] "w"(w8)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v21", "v22");
wc0 += 9 * hout_c_block;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
}
#else // not __aarch64__
for (int i = 0; i < chin; ++i) {
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
//! get valid weights of current output channel
float w_tmp[10] = {
wc0[c], wc0[c + 4], wc0[c + 8], wc0[c + 12], wc0[c + 16],
wc0[c + 20], wc0[c + 24], wc0[c + 28], wc0[c + 32], 0.f};
float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0
float32x4_t w1 = vld1q_f32(w_tmp + 3); // w3, w4, w5, q1
float32x4_t w2 = vld1q_f32(w_tmp + 6); // w6, w7, w8, q2
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
int cnt = w_loop / 2;
if (cnt > 0) {
asm volatile(
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
"or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r0, 8 "
"float\n"
"vld1.32 {d10}, [%[r0]] @ load r0, 2 "
"float\n"
/* main loop */
"0: @ main loop\n"
/* r0 * w0, w1, w2, get out r0*/
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, "
"or11\n"
"vext.32 q8, q3, q4, #1 @ r0, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r0, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, "
"0, 1, 2, 3\n"
"vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, "
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r0, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r0, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r1]] @ load r1, 2 "
"float\n"
/* r1 * w3, w4, w5, get out r0*/
/* r1 * w0, w1, w2, get out r1*/
"vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, "
"0, 1, 2, 3\n"
"vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, "
"4, 5, 6, 7\n"
"vext.32 q8, q3, q4, #1 @ r1, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r1, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, "
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r1, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r1, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, "
"5, 6, 7, 8\n"
"vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, "
"1, 2, 3, 4\n"
"vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, "
"6, 7, 8, 9\n"
"vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r2]] @ load r2, 2 "
"float\n"
/* r2 * w6, w7, w8, get out r0*/
/* r2 * w3, w4, w5, get out r1*/
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, "
"0, 1, 2, 3\n"
"vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, "
"4, 5, 6, 7\n"
"vext.32 q8, q3, q4, #1 @ r2, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r2, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, "
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r2, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r2, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, "
"5, 6, 7, 8\n"
"vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, "
"1, 2, 3, 4\n"
"vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r3]]! @ load r3, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, "
"6, 7, 8, 9\n"
"vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r3]] @ load r3, 2 "
"float\n"
/* r3 * w6, w7, w8, get out r1*/
"vext.32 q8, q3, q4, #1 @ r3, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r3, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, "
"4, 5, 6, 7\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, "
"or01\n"
"vext.32 q10, q3, q4, #2 @ r3, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r3, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, "
"4, 5, 6, 7\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
"or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r3, 8 "
"float\n"
"vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r0]] @ load r0, 2 "
"float\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, "
"or11\n"
"subs %[cnt], #1 @loop count "
"-1\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1),
[r2] "+r"(r2), [r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
: "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "q14", "q15");
r0 -= 8;
}
//! deal with the remaining output width
if (w_loop & 1) {
ptr_out0[0] +=
r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] +
r1[0] * w_tmp[3] + r1[1] * w_tmp[4] + r1[2] * w_tmp[5] +
r2[0] * w_tmp[6] + r2[1] * w_tmp[7] + r2[2] * w_tmp[8];
ptr_out0[1] +=
r0[1] * w_tmp[0] + r0[2] * w_tmp[1] + r0[3] * w_tmp[2] +
r1[1] * w_tmp[3] + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] +
r2[1] * w_tmp[6] + r2[2] * w_tmp[7] + r2[3] * w_tmp[8];
ptr_out0[2] +=
r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] +
r1[2] * w_tmp[3] + r1[3] * w_tmp[4] + r1[4] * w_tmp[5] +
r2[2] * w_tmp[6] + r2[3] * w_tmp[7] + r2[4] * w_tmp[8];
ptr_out0[3] +=
r0[3] * w_tmp[0] + r0[4] * w_tmp[1] + r0[5] * w_tmp[2] +
r1[3] * w_tmp[3] + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] +
r2[3] * w_tmp[6] + r2[4] * w_tmp[7] + r2[5] * w_tmp[8];
ptr_out1[0] +=
r1[0] * w_tmp[0] + r1[1] * w_tmp[1] + r1[2] * w_tmp[2] +
r2[0] * w_tmp[3] + r2[1] * w_tmp[4] + r2[2] * w_tmp[5] +
r3[0] * w_tmp[6] + r3[1] * w_tmp[7] + r3[2] * w_tmp[8];
ptr_out1[1] +=
r1[1] * w_tmp[0] + r1[2] * w_tmp[1] + r1[3] * w_tmp[2] +
r2[1] * w_tmp[3] + r2[2] * w_tmp[4] + r2[3] * w_tmp[5] +
r3[1] * w_tmp[6] + r3[2] * w_tmp[7] + r3[3] * w_tmp[8];
ptr_out1[2] +=
r1[2] * w_tmp[0] + r1[3] * w_tmp[1] + r1[4] * w_tmp[2] +
r2[2] * w_tmp[3] + r2[3] * w_tmp[4] + r2[4] * w_tmp[5] +
r3[2] * w_tmp[6] + r3[3] * w_tmp[7] + r3[4] * w_tmp[8];
ptr_out1[3] +=
r1[3] * w_tmp[0] + r1[4] * w_tmp[1] + r1[5] * w_tmp[2] +
r2[3] * w_tmp[3] + r2[4] * w_tmp[4] + r2[5] * w_tmp[5] +
r3[3] * w_tmp[6] + r3[4] * w_tmp[7] + r3[5] * w_tmp[8];
}
wc0 += 36;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
}
#endif // __aarch64__
block_inr0 = block_inr2;
block_inr1 = block_inr3;
block_inr2 = block_inr1 + in_len;
block_inr3 = block_inr2 + in_len;
}
slidingwindow_writeout_c1_fp32(pre_out, dout_batch, c_idx, c_idx + 1, h,
h + h_kernel, 0, wout_round, chout, hout,
wout, relu, ptr_write);
}
}
}
}
template <>
void SlidingwindowConv3x3s2Faster<float, float>(
const framework::Tensor *input, framework::Tensor *filter,
const std::vector<int> &paddings, framework::Tensor *output) {
const float *din = input->data<float>();
float *dout = output->mutable_data<float>();
const float *weights = filter->mutable_data<float>();
const float *bias = nullptr;
bool relu = false;
const int num = input->dims()[0];
const int chin = input->dims()[1];
const int hin = input->dims()[2];
const int win = input->dims()[3];
const int chout = output->dims()[1];
const int hout = output->dims()[2];
const int wout = output->dims()[3];
const int pad_h = paddings[0];
const int pad_w = paddings[1];
const int threads = framework::CPUContext::Context()->get_thread_num();
int l2_size =
framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float);
const int hout_c_block = 4;
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round * 2 /*stride_w*/ + 1;
//! get h block:
//!   win_round * chin * hin_r_block +
//!       wout_round * hout_c_block * hout_r_block * threads = l2_size
//!   where win_round = 2 * wout_round + 1 and
//!         hin_r_block = 2 * hout_r_block + 1
int hout_r_block =
(l2_size - 2 * wout_round * chin - chin) /
((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
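// The same worked example as in the s1 kernel, now for stride 2
// (hypothetical shapes): with l2_size = 131072 floats, chin = 32,
// wout_round = 112 and threads = 4:
//   hout_r_block = (131072 - 2 * 112 * 32 - 32) /
//                  ((4 * 112 + 2) * 32 + 112 * 4 * 4) = 123872 / 16192 = 7
// which is then rounded down to a multiple of hout_r_kernel (2) -> 6.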
const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1;
float ptr_zero[win_round];
memset(ptr_zero, 0, sizeof(float) * win_round);
float ptr_write[wout_round];
int in_len = win_round * chin;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = hout_c_block * hout_r_block * wout_round;
float *pre_din =
static_cast<float *>(framework::CPUContext::Context()->get_work_space(
(pre_in_size + threads * pre_out_size) * sizeof(float)));
int size_in_channel = win * hin;
int size_out_channel = wout * hout;
int w_stride = chin * 9; /*kernel_w * kernel_h*/
int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h *
int ws = -pad_w;
int we = ws + win_round;
int w_loop = wout_round / 4;
int c_remain = chout - (chout / hout_c_block) * hout_c_block;
int c_round_down = (chout / hout_c_block) * hout_c_block;
int out_row_stride = hout_c_block * wout_round;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * chin * size_in_channel;
float *dout_batch = dout + n * chout * size_out_channel;
for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
int hs = h * 2 /*stride_h*/ - pad_h;
int he = hs + h_kernel * 2 /*stride_h*/ + 1;
slidingwindow_prepack_input(din_batch, pre_din, 0, chin, hs, he, ws, we,
chin, win, hin, ptr_zero);
const float *cblock_inr0 = pre_din;
const float *cblock_inr1 = cblock_inr0 + in_len;
const float *cblock_inr2 = cblock_inr1 + in_len;
const float *cblock_inr3 = cblock_inr2 + in_len;
const float *cblock_inr4 = cblock_inr3 + in_len;
#pragma omp parallel for
for (int c = 0; c < c_round_down; c += hout_c_block) {
#ifdef _OPENMP
float *pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
#else
float *pre_out = pre_din + pre_in_size;
#endif
const float *block_inr0 = cblock_inr0;
const float *block_inr1 = cblock_inr1;
const float *block_inr2 = cblock_inr2;
const float *block_inr3 = cblock_inr3;
const float *block_inr4 = cblock_inr4;
const float *weight_c = weights + c * w_stride;
const float *bias_ptr = ptr_zero;
if (bias != nullptr) {
bias_ptr = bias + c;
}
slidingwindow_fill_bias(pre_out, bias_ptr,
wout_round * hout_c_block * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const float *wc0 = weight_c;
const float *inr0 = block_inr0;
const float *inr1 = block_inr1;
const float *inr2 = block_inr2;
const float *inr3 = block_inr3;
const float *inr4 = block_inr4;
float *pre_out0 = pre_out + hk * out_row_stride;
float *pre_out1 = pre_out0 + out_row_stride;
#ifdef __aarch64__
for (int i = 0; i < chin; ++i) {
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
float32x4_t w0 = vld1q_f32(wc0); // w0, v23
float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24
float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25
float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26
float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27
float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28
float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29
float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30
float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
const float *r4 = inr4;
int cnt = w_loop;
asm volatile(
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"2: \n" /* main loop*/
/* r0, r2, mul w0, get out r0, r1 */
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
"fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/
"fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/
"fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/
"fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/
"fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/
"fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/
"fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/
"fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/
"ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/
/* r2 mul w6, get out r0*/
"fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/
"fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/
"fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/
"ldr d11, [%[r1]] \n" /* load input r1, 9th
element*/
/* r0, r2, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/
"fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/
"fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/
"fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/
"fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/
"fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/
"fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/
"fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/
"ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/
/* r2 mul w7, get out r0 */
"fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/
"fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/
"fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/
"ldr d13, [%[r3]] \n" /* load input r3, 9th
element*/
/* r0, r2, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/
"fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/
"fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 *
r0[8]*/
"fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/
"fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/
"fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/
"fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 *
r2[8]*/
"ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/
/* r2, mul w8, get out r0 */
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/
"fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/
"fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/
"fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 *
r2[8]*/
"ldr d14, [%[r4]] \n" /* load input r4, 9th
element*/
/* r1, r3, mul w3, get out r0, r1 */
"fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/
"fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/
"fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/
"fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/
"fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/
"fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/
"fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/
"fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/
/* r1, r3, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/
"fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/
"fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/
"fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/
"fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/
"fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/
"fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/
"fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
/* r1, r3, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/
"fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/
"fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/
"fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 *
r1[8]*/
"ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/
"stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/
"fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/
"fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/
"fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/
"fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 *
r3[8]*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/
/* r4, mul w6, get out r1 */
"fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/
"fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/
"fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/
"fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
/* r4, mul w7, get out r1 */
"fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/
"fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/
"fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/
"fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
/* r4, mul w8, get out r1 */
"fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/
"fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/
"fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/
"fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 *
r4[8]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/
"stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[r3] "+r"(r3), [r4] "+r"(r4), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3),
[w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7),
[w8] "w"(w8)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22");
wc0 += 9 * hout_c_block;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
inr4 += win_round;
}
#else // not __aarch64__
for (int i = 0; i < chin; ++i) {
const float *wc0 = weight_c + i * w_stride_chin;
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
const float *r4 = inr4;
int cnt = w_loop;
asm volatile(
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
/* load weights */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
/* load r0, r2 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, "
"8 float\n"
"vld1.32 {d8}, [%[r0]] @ load r0, "
"9th float\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
/* main loop */
"0: @ main "
"loop\n"
/* mul r0, with w0, w1, w2 */
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load "
"outr1, w0, w1, c0~c3\n"
"vmla.f32 q8, q5, d0[0] @ w0 * "
"inr00\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load "
"outr1, w2, w3, c0~c3\n"
"vmla.f32 q9, q5, d1[0] @ w0 * "
"inr02\n"
"vmla.f32 q10, q5, d2[0] @ w0 * "
"inr04\n"
"vmla.f32 q11, q5, d3[0] @ w0 * "
"inr06\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w1 * "
"inr01\n"
"vmla.f32 q9, q6, d1[1] @ w1 * "
"inr03\n"
"vmla.f32 q10, q6, d2[1] @ w1 * "
"inr05\n"
"vmla.f32 q11, q6, d3[1] @ w1 * "
"inr07\n"
"vld1.32 {d9}, [%[r2]] @ load r2, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w2 * "
"inr02\n"
"vmla.f32 q9, q7, d2[0] @ w2 * "
"inr04\n"
"vmla.f32 q10, q7, d3[0] @ w2 * "
"inr06\n"
"vmla.f32 q11, q7, d8[0] @ w2 * "
"inr08\n"
"sub %[r2], %[r2], #32 @ r2 - 32, "
"load r2 twice\n"
/* mul r2, with w0, w1, w2 */
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, "
"8 float\n"
"vmla.f32 q12, q5, d4[0] @ w0 * "
"inr20\n"
"vmla.f32 q13, q5, d5[0] @ w0 * "
"inr22\n"
"vmla.f32 q14, q5, d6[0] @ w0 * "
"inr24\n"
"vmla.f32 q15, q5, d7[0] @ w0 * "
"inr26\n"
"vld1.32 {d8}, [%[r1]] @ load r1, "
"9th float\n"
"vmla.f32 q12, q6, d4[1] @ w1 * "
"inr21\n"
"vmla.f32 q13, q6, d5[1] @ w1 * "
"inr23\n"
"vmla.f32 q14, q6, d6[1] @ w1 * "
"inr25\n"
"vmla.f32 q15, q6, d7[1] @ w1 * "
"inr27\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, "
"w4, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w2 * "
"inr22\n"
"vmla.f32 q13, q7, d6[0] @ w2 * "
"inr24\n"
"vmla.f32 q14, q7, d7[0] @ w2 * "
"inr26\n"
"vmla.f32 q15, q7, d9[0] @ w2 * "
"inr28\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, "
"to q7\n"
/* mul r1, with w3, w4, w5 */
"vmla.f32 q8, q5, d0[0] @ w3 * "
"inr10\n"
"vmla.f32 q9, q5, d1[0] @ w3 * "
"inr12\n"
"vmla.f32 q10, q5, d2[0] @ w3 * "
"inr14\n"
"vmla.f32 q11, q5, d3[0] @ w3 * "
"inr16\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w4 * "
"inr11\n"
"vmla.f32 q9, q6, d1[1] @ w4 * "
"inr13\n"
"vmla.f32 q10, q6, d2[1] @ w4 * "
"inr15\n"
"vmla.f32 q11, q6, d3[1] @ w4 * "
"inr17\n"
"vld1.32 {d9}, [%[r3]] @ load r3, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w5 * "
"inr12\n"
"vmla.f32 q9, q7, d2[0] @ w5 * "
"inr14\n"
"vmla.f32 q10, q7, d3[0] @ w5 * "
"inr16\n"
"vmla.f32 q11, q7, d8[0] @ w5 * "
"inr18\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 "
"- 32, to start address\n"
/* mul r3, with w3, w4, w5 */
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, "
"8 float\n"
"vmla.f32 q12, q5, d4[0] @ w3 * "
"inr30\n"
"vmla.f32 q13, q5, d5[0] @ w3 * "
"inr32\n"
"vmla.f32 q14, q5, d6[0] @ w3 * "
"inr34\n"
"vmla.f32 q15, q5, d7[0] @ w3 * "
"inr36\n"
"vld1.32 {d8}, [%[r2]] @ load r2, "
"9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * "
"inr31\n"
"vmla.f32 q13, q6, d5[1] @ w4 * "
"inr33\n"
"vmla.f32 q14, q6, d6[1] @ w4 * "
"inr35\n"
"vmla.f32 q15, q6, d7[1] @ w4 * "
"inr37\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
"w7, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w5 * "
"inr32\n"
"vmla.f32 q13, q7, d6[0] @ w5 * "
"inr34\n"
"vmla.f32 q14, q7, d7[0] @ w5 * "
"inr36\n"
"vmla.f32 q15, q7, d9[0] @ w5 * "
"inr38\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
"to q7\n"
/* mul r2, with w6, w7, w8 */
"vmla.f32 q8, q5, d0[0] @ w6 * "
"inr20\n"
"vmla.f32 q9, q5, d1[0] @ w6 * "
"inr22\n"
"vmla.f32 q10, q5, d2[0] @ w6 * "
"inr24\n"
"vmla.f32 q11, q5, d3[0] @ w6 * "
"inr26\n"
"vld1.32 {d4-d7}, [%[r4]]! @ load r4, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w7 * "
"inr21\n"
"vmla.f32 q9, q6, d1[1] @ w7 * "
"inr23\n"
"vmla.f32 q10, q6, d2[1] @ w7 * "
"inr25\n"
"vmla.f32 q11, q6, d3[1] @ w7 * "
"inr27\n"
"vld1.32 {d9}, [%[r4]] @ load r4, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w8 * "
"inr22\n"
"vmla.f32 q9, q7, d2[0] @ w8 * "
"inr24\n"
"vmla.f32 q10, q7, d3[0] @ w8 * "
"inr26\n"
"vmla.f32 q11, q7, d8[0] @ w8 * "
"inr28\n"
"sub %[wc0], %[wc0], #144 @ wc0 - "
"144 to start address\n"
/* mul r4, with w6, w7, w8 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, 8 float\n"
"vmla.f32 q12, q5, d4[0] @ w6 * inr40\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01, c0~c3\n"
"vmla.f32 q13, q5, d5[0] @ w6 * inr42\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03, c0~c3\n"
"vmla.f32 q14, q5, d6[0] @ w6 * inr44\n"
"vmla.f32 q15, q5, d7[0] @ w6 * inr46\n"
"vld1.32 {d8}, [%[r0]] @ load r0, 9th float\n"
"vmla.f32 q12, q6, d4[1] @ w7 * inr41\n"
"vmla.f32 q13, q6, d5[1] @ w7 * inr43\n"
"vmla.f32 q14, q6, d6[1] @ w7 * inr45\n"
"vmla.f32 q15, q6, d7[1] @ w7 * inr47\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w8 * inr42\n"
"vmla.f32 q13, q7, d6[0] @ w8 * inr44\n"
"vmla.f32 q14, q7, d7[0] @ w8 * inr46\n"
"vmla.f32 q15, q7, d9[0] @ w8 * inr48\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
"r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
"r12, r13, c0~c3\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"subs %[cnt], #1 @ loop "
"count--\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[r3] "+r"(r3), [r4] "+r"(r4), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1), [wc0] "+r"(wc0)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
inr4 += win_round;
}
#endif // __aarch64__
block_inr0 = block_inr4;
block_inr1 = block_inr0 + in_len;
block_inr2 = block_inr1 + in_len;
block_inr3 = block_inr2 + in_len;
block_inr4 = block_inr3 + in_len;
}
slidingwindow_writeout_c4_fp32(pre_out, dout_batch, c, c + hout_c_block,
h, h + h_kernel, 0, wout_round, chout,
hout, wout, relu, ptr_write);
}
#pragma omp parallel for
for (int c = 0; c < c_remain; ++c) {
#ifdef _OPENMP
float *pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
#else
float *pre_out = pre_din + pre_in_size;
#endif
const float *block_inr0 = cblock_inr0;
const float *block_inr1 = cblock_inr1;
const float *block_inr2 = cblock_inr2;
const float *block_inr3 = cblock_inr3;
const float *block_inr4 = cblock_inr4;
//! get weights ptr of remained
const float *weight_c = weights + c_round_down * w_stride;
//! fill bias to one channel
const float *bias_ptr = ptr_zero;
if (bias != nullptr) {
bias_ptr = bias + c_round_down + c;
}
slidingwindow_fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const float *wc0 = weight_c;
const float *inr0 = block_inr0;
const float *inr1 = block_inr1;
const float *inr2 = block_inr2;
const float *inr3 = block_inr3;
const float *inr4 = block_inr4;
float *pre_out0 = pre_out + hk * wout_round;
float *pre_out1 = pre_out0 + wout_round;
#ifdef __aarch64__
for (int i = 0; i < chin; ++i) {
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
//! get valid weights of current output channel
float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23
float32x4_t w1 = vdupq_n_f32(wc0[c + 4]); // w1, v24
float32x4_t w2 = vdupq_n_f32(wc0[c + 8]); // w2, v25
float32x4_t w3 = vdupq_n_f32(wc0[c + 12]); // w3, v26
float32x4_t w4 = vdupq_n_f32(wc0[c + 16]); // w4, v27
float32x4_t w5 = vdupq_n_f32(wc0[c + 20]); // w5, v28
float32x4_t w6 = vdupq_n_f32(wc0[c + 24]); // w6, v29
float32x4_t w7 = vdupq_n_f32(wc0[c + 28]); // w7, v30
float32x4_t w8 = vdupq_n_f32(wc0[c + 32]); // w8, v31
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
const float *r4 = inr4;
int cnt = w_loop;
asm volatile(
"ldr q21, [%[ptr_out0]] \n" /* load outr00, outr01,
outr02, outr03*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"2: \n" /* main loop*/
/* r0, r2, mul w0, get out r0, r1 */
"ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11,
outr12, outr13*/
"fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2,
4, 6]*/
"fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2,
4, 6]*/
"ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/
/* r2 mul w6, get out r0*/
"fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2,
4, 6]*/
"ldr d11, [%[r1]] \n" /* load input r1, 9th
element*/
/* shift left 1 */
"ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/
"ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/
/* r0, r2, mul w1, get out r0, r1 */
"fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3,
5, 7]*/
"fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3,
5, 7]*/
"ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/
/* r2 mul w7, get out r0 */
"fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1,
3, 5, 7]*/
"ldr d13, [%[r3]] \n" /* load input r3, 9th
element*/
/* r0, r2, mul w2, get out r0, r1 */
"fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * r0[2, 4,
6, 8]*/
"fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4,
6, 8]*/
"ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/
/* r2, mul w8, get out r0 */
"fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2,
4, 6, 8]*/
"ldr d14, [%[r4]] \n" /* load input r4, 9th
element*/
/* r1, r3, mul w3, get out r0, r1 */
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2,
4, 6]*/
"fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2,
4, 6]*/
/* shift left 1 */
"ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/
"ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/
/* r1, r3, mul w4, get out r0, r1 */
"fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3,
5, 7]*/
"fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3,
5, 7]*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
/* r1, r3, mul w5, get out r0, r1 */
"fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/
"fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/
/* r4, mul w6, get out r1 */
"fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4[0, 2,
4, 6]*/
"ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/
"ldr q21, [%[ptr_out0]] \n" /* load outr0*/
/* r4, mul w7, get out r1 */
"fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3,
5, 7]*/
/* r4, mul w8, get out r1 */
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4,
6, 8]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"str q22, [%[ptr_out1]], #16 \n" /* save outr1*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[r3] "+r"(r3), [r4] "+r"(r4), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3),
[w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7),
[w8] "w"(w8)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v21", "v22");
wc0 += 36;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
inr4 += win_round;
}
#else // not __aarch64__
for (int i = 0; i < chin; ++i) {
float *ptr_out0 = pre_out0;
float *ptr_out1 = pre_out1;
//! get valid weights of current output channel
float w_tmp[12] = {wc0[c], wc0[c + 4], wc0[c + 8], 0.f,
wc0[c + 12], wc0[c + 16], wc0[c + 20], 0.f,
wc0[c + 24], wc0[c + 28], wc0[c + 32], 0.f};
float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0
float32x4_t w1 = vld1q_f32(w_tmp + 4); // w3, w4, w5, q1
float32x4_t w2 = vld1q_f32(w_tmp + 8); // w6, w7, w8, q2
const float *r0 = inr0;
const float *r1 = inr1;
const float *r2 = inr2;
const float *r3 = inr3;
const float *r4 = inr4;
int cnt = w_loop / 2;
if (cnt > 0) {
asm volatile(
/* main loop */
"0: @ "
"main loop\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
"or01\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, "
"or11\n"
"vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 "
"float, interleave\n"
"vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 "
"float, interleave\n"
"vld1.32 {d22}, [%[r2]] @ load 16th "
"float\n"
/* r2 * w2, r2 * w0, get or0, or1 */
"vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, "
"9, 11, 13, 15\n"
"vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 "
"float, interleave\n"
"vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, "
"9, 11, 13, 15\n"
"vext.32 q4, q3, q5, #1 @ r2, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r2, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, "
"8, 10, 12, 14\n"
"vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 "
"float, interleave\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, "
"8, 10, 12, 14\n"
"vld1.32 {d22}, [%[r0]] @ load 16th "
"float\n"
"vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, "
"2, 4, 6, 8\n"
"vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, "
"2, 4, 6, 8\n"
"vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 "
"float, interleave\n"
"vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, "
"10, 12, 14, 16\n"
"vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, "
"10, 12, 14, 16\n"
"vld2.32 {d10-d13}, [%[r3]]! @ load r3, 8 "
"float, interleave\n"
/* r0 * w0, get or0, r3 * w1, get or1*/
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, "
"9, 11, 13, 15\n"
"vext.32 q8, q7, q9, #1 @ r0, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q10, q9, q11, #1 @ r0, shift "
"left 1, get 10, 12, 14, 16\n"
"vld1.32 {d22}, [%[r3]] @ load 16th "
"float\n"
"vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, "
"9, 11, 13, 15\n"
"vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, "
"8, 10, 12, 14\n"
"vext.32 q4, q3, q5, #1 @ r3, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r3, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, "
"8, 10, 12, 14\n"
"vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, "
"2, 4, 6, 8\n"
"vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 "
"float, interleave\n"
"vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, "
"10, 12, 14, 16\n"
"vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 "
"float, interleave\n"
"vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, "
"2, 4, 6, 8\n"
"vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 "
"float, interleave\n"
"vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, "
"10, 12, 14, 16\n"
"vld2.32 {d10-d13}, [%[r4]]! @ load r4, 8 "
"float, interleave\n"
"vld1.32 {d22}, [%[r1]] @ load 16th "
"float\n"
/* r1 * w1, get or0, r4 * w2, get or1 */
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, "
"9, 11, 13, 15\n"
"vext.32 q8, q7, q9, #1 @ r1, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q10, q9, q11, #1 @ r1, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, "
"9, 11, 13, 15\n"
"vld1.32 {d22}, [%[r4]] @ load 16th "
"float\n"
"vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, "
"8, 10, 12, 14\n"
"vext.32 q4, q3, q5, #1 @ r1, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r1, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, "
"8, 10, 12, 14\n"
"vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, "
"2, 4, 6, 8\n"
"vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, "
"10, 12, 14, 16\n"
"vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, "
"2, 4, 6, 8\n"
"vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, "
"10, 12, 14, 16\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n"
"subs %[cnt], #1 @loop count "
"-1\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1),
[r2] "+r"(r2), [r3] "+r"(r3), [r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
: "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "q14", "q15");
}
//! deal with remain wout
if (w_loop & 1) {
ptr_out0[0] +=
r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] +
r1[0] * w_tmp[4] + r1[1] * w_tmp[5] + r1[2] * w_tmp[6] +
r2[0] * w_tmp[8] + r2[1] * w_tmp[9] + r2[2] * w_tmp[10];
ptr_out0[1] +=
r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] +
r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + r1[4] * w_tmp[6] +
r2[2] * w_tmp[8] + r2[3] * w_tmp[9] + r2[4] * w_tmp[10];
ptr_out0[2] +=
r0[4] * w_tmp[0] + r0[5] * w_tmp[1] + r0[6] * w_tmp[2] +
r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + r1[6] * w_tmp[6] +
r2[4] * w_tmp[8] + r2[5] * w_tmp[9] + r2[6] * w_tmp[10];
ptr_out0[3] +=
r0[6] * w_tmp[0] + r0[7] * w_tmp[1] + r0[8] * w_tmp[2] +
r1[6] * w_tmp[4] + r1[7] * w_tmp[5] + r1[8] * w_tmp[6] +
r2[6] * w_tmp[8] + r2[7] * w_tmp[9] + r2[8] * w_tmp[10];
ptr_out1[0] +=
r2[0] * w_tmp[0] + r2[1] * w_tmp[1] + r2[2] * w_tmp[2] +
r3[0] * w_tmp[4] + r3[1] * w_tmp[5] + r3[2] * w_tmp[6] +
r4[0] * w_tmp[8] + r4[1] * w_tmp[9] + r4[2] * w_tmp[10];
ptr_out1[1] +=
r2[2] * w_tmp[0] + r2[3] * w_tmp[1] + r2[4] * w_tmp[2] +
r3[2] * w_tmp[4] + r3[3] * w_tmp[5] + r3[4] * w_tmp[6] +
r4[2] * w_tmp[8] + r4[3] * w_tmp[9] + r4[4] * w_tmp[10];
ptr_out1[2] +=
r2[4] * w_tmp[0] + r2[5] * w_tmp[1] + r2[6] * w_tmp[2] +
r3[4] * w_tmp[4] + r3[5] * w_tmp[5] + r3[6] * w_tmp[6] +
r4[4] * w_tmp[8] + r4[5] * w_tmp[9] + r4[6] * w_tmp[10];
ptr_out1[3] +=
r2[6] * w_tmp[0] + r2[7] * w_tmp[1] + r2[8] * w_tmp[2] +
r3[6] * w_tmp[4] + r3[7] * w_tmp[5] + r3[8] * w_tmp[6] +
r4[6] * w_tmp[8] + r4[7] * w_tmp[9] + r4[8] * w_tmp[10];
}
wc0 += 36;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
inr4 += win_round;
}
#endif // __aarch64__
block_inr0 = block_inr4;
block_inr1 = block_inr0 + in_len;
block_inr2 = block_inr1 + in_len;
block_inr3 = block_inr2 + in_len;
block_inr4 = block_inr3 + in_len;
}
slidingwindow_writeout_c1_fp32(
pre_out, dout_batch, c + c_round_down, c + c_round_down + 1, h,
h + h_kernel, 0, wout_round, chout, hout, wout, relu, ptr_write);
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -33,6 +33,17 @@ void SlidingwindowConv3x3s2(const framework::Tensor *input,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void SlidingwindowConv3x3s1Faster(const framework::Tensor *input,
framework::Tensor *filter,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void SlidingwindowConv3x3s2Faster(const framework::Tensor *input,
framework::Tensor *filter,
const std::vector<int> &paddings,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/slidingwindow_utils.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void slidingwindow_fill_bias(float* dout, const float* bias, int ch_num,
int ch_size) {
for (int j = 0; j < ch_num; j++) {
float32x4_t vb = vdupq_n_f32(bias[j]);
int i = 0;
for (; i < ch_size - 3; i += 4) {
vst1q_f32(dout + i, vb);
}
for (; i < ch_size; i++) {
dout[i] = bias[j];
}
dout += ch_size;
}
}
/* write result in outputs
* input din: [n, c, h, w], output dout: [n, c, h, w]
*/
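/* Region semantics (as exercised by the callers above): [cs, ce) is the
 * output-channel range, [hs, he) the output-row range and [ws, we) the
 * output-column range; he and we may exceed height/width because the compute
 * block is rounded up, and the excess rows/columns are clipped below. */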
void slidingwindow_writeout_c1_fp32(const float* din, float* dout, int cs,
int ce, int hs, int he, int ws, int we,
int channel, int height, int width,
bool flag_relu, float* trash_ptr) {
if (cs > channel) {
return;
}
const int c1 = 1;
const int w4 = 4;
int size_c_out = width * height;
float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
const float* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int w_round = we - ws;
int cnt = (width - ws) / w4;
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
const float* din_hei_ptr = ptr_din + i * w_round * c1;
if (cnt > 0) {
int cnt_loop = cnt;
if (flag_relu) {
#ifdef __aarch64__
asm volatile(
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop */
"fmax v1.4s, v0.4s, v20.4s \n" /* relu */
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
"str q1, [%[doutc0r0]], #16 \n" /* store c0r0 */
"bne 1b \n" /* jump to main loop */
: [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0", "v1", "v20");
#else
asm volatile(
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, c1r0, "
"c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vmax.f32 q1, q0, q15 @ relu\n"
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n"
"vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q15");
#endif
} else {
#ifdef __aarch64__
asm volatile(
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"1: \n" /* main loop */
"str q0, [%[doutc0r0]], #16 \n" /* store c2r0 */
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"bne 1b \n" /* jump to main loop */
: [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0");
#else
asm volatile(
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, c0r1, "
"c0r2, c0r3\n"
"1: @ main loop\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0");
#endif
}
}
if (we > width) {
int offset = i * w_round * c1 + c1 * w4 * cnt;
din_hei_ptr = ptr_din + offset;
int j = we - w4;
if (flag_relu) {
for (; j < width; ++j) {
*(doutc0_ptr++) = std::max(din_hei_ptr[0], 0.f);
din_hei_ptr++;
}
} else {
for (; j < width; ++j) {
*(doutc0_ptr++) = *(din_hei_ptr++);
}
}
}
}
}
/* write result in outputs
* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
*/
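/* Layout sketch (illustrative): each pixel of din carries its 4 channels
 * interleaved, i.e. [c0w0 c1w0 c2w0 c3w0, c0w1 c1w1 c2w1 c3w1, ...]; the 4x4
 * transposes below (trn1/trn2 on aarch64, vtrn/vswp on armv7) de-interleave
 * four pixels at a time into the four planar output rows, and channels past
 * `channel` are routed to trash_ptr. */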
void slidingwindow_writeout_c4_fp32(const float* din, float* dout, int cs,
int ce, int hs, int he, int ws, int we,
int channel, int height, int width,
bool flag_relu, float* trash_ptr) {
const int c4 = 4;
const int w4 = 4;
const int w_round = we - ws;
const int ch_n = ce - cs;
int size_c_out = width * height;
float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
float* doutc1r0 = doutc0r0 + size_c_out;
float* doutc2r0 = doutc1r0 + size_c_out;
float* doutc3r0 = doutc2r0 + size_c_out;
const float* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int cnt = (width - ws) / w4;
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
float* doutc1_ptr = doutc1r0 + size_w;
float* doutc2_ptr = doutc2r0 + size_w;
float* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
// intentional fallthrough: route the out-of-range trailing channels to trash_ptr
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr; /* fall through */
case 2:
doutc2_ptr = trash_ptr; /* fall through */
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
const float* din_hei_ptr = ptr_din + i * w_round * ch_n;
if (cnt > 0) {
int cnt_loop = cnt;
if (flag_relu) {
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop */
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1 */
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1 */
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3 */
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10 */
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10 */
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11 */
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11 */
"fmax v16.4s, v16.4s, v20.4s \n" /* relu */
"fmax v17.4s, v17.4s, v20.4s \n" /* relu */
"fmax v18.4s, v18.4s, v20.4s \n" /* relu */
"fmax v19.4s, v19.4s, v20.4s \n" /* relu */
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0 */
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0 */
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0 */
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0 */
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
"bne 1b \n" /* jump to main loop */
: [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_hei_ptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v16", "v17", "v18", "v19", "v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n"
"vmov.u32 q15, #0 @ dump zero \n"
"1: @ main loop \n"
"vtrn.32 q0, q1 @ trans data:c00c01c20c21 "
"\n"
"vtrn.32 q2, q3 @ trans data:c02c03c22c23 "
"\n"
"vswp d1, d4 @ swap data\n"
"vswp d3, d6 @ swap data\n"
"vmax.f32 q0, q0, q15 @ relu\n"
"vmax.f32 q1, q1, q15 @ relu\n"
"vmax.f32 q2, q2, q15 @ relu\n"
"vmax.f32 q3, q3, q15 @ relu\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr), [cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
} else {
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"1: \n" /* main loop */
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1 */
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1 */
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3 */
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10 */
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10 */
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11 */
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11 */
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0 */
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0 */
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0 */
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0 */
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
"bne 1b \n" /* jump to main loop */
: [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_hei_ptr)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17",
"v18", "v19");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n"
"1: @ main loop \n"
"vtrn.32 q0, q1 @ trans data:c00c01c20c21 "
"\n"
"vtrn.32 q2, q3 @ trans data:c02c03c22c23 "
"\n"
"vswp d1, d4 @ swap data\n"
"vswp d3, d6 @ swap data\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr), [cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3");
#endif
}
}
if (we > width) {
int offset = i * w_round * c4 + c4 * w4 * cnt;
din_hei_ptr = ptr_din + offset;
int j = we - w4;
if (flag_relu) {
for (; j < width; ++j) {
*(doutc0_ptr++) = std::max(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = std::max(din_hei_ptr[1], 0.f);
*(doutc2_ptr++) = std::max(din_hei_ptr[2], 0.f);
*(doutc3_ptr++) = std::max(din_hei_ptr[3], 0.f);
din_hei_ptr += w4;
}
} else {
for (; j < width; ++j) {
*(doutc0_ptr++) = din_hei_ptr[0];
*(doutc1_ptr++) = din_hei_ptr[1];
*(doutc2_ptr++) = din_hei_ptr[2];
*(doutc3_ptr++) = din_hei_ptr[3];
din_hei_ptr += w4;
}
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstring>  // memcpy, used by slidingwindow_prepack_input
#include "framework/tensor.h"
#if __ARM_NEON
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
/* preprocessing weights
* input weights: [chout, chin/ group, kh, kw] --> outputs weights: [chout / n,
* chin/ group, kh, kw, n]
*/
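/* Per element, the loops below implement (for n = 4):
 *   dout[co / 4][ci][kh][kw][co % 4] = din[co][ci][kh][kw]
 * and when chout is not a multiple of 4, the trailing slots of the last
 * group are padded by re-reading the first channel of that group. */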
template <typename dtype>
void slidingwindow_transform_weight(const framework::Tensor& weight,
framework::Tensor* output) {
int chout = weight.dims()[0];
int chin = weight.dims()[1];
int kernel_size = weight.dims()[2] * weight.dims()[3];
const int n = 4;
int cround = (chout + n - 1) / n * n;
const dtype* din = weight.data<dtype>();
dtype* dout = output->mutable_data<dtype>({cround, chin, 3, 3});
int c_loop = chout / n;
int chout_round = (chout + n - 1) / n;
int win_stride = chin * kernel_size;
int wout_stride = n * win_stride;
int co = 0;
for (; co < c_loop; ++co) {
dtype* dout_c = dout + co * wout_stride;
const dtype* din_array[n];
din_array[0] = din + co * wout_stride;
for (int i = 1; i < n; i++) {
din_array[i] = din_array[i - 1] + win_stride;
}
for (int ci = 0; ci < chin; ++ci) {
for (int k = 0; k < kernel_size; ++k) {
for (int i = 0; i < n; i++) {
*(dout_c++) = *(din_array[i]++);
}
}
}
}
// pad final chout
if (chout_round > c_loop) {
dtype* dout_c = dout + c_loop * wout_stride;
const dtype* din_array[n];
din_array[0] = din + c_loop * wout_stride;
for (int i = 1; i < n; i++) {
din_array[i] = din_array[i - 1] + win_stride;
}
// deal remain
int cremain = chout_round * n - chout;
for (int i = 1; i <= cremain; i++) {
din_array[n - i] = din_array[0];
}
for (int ci = 0; ci < chin; ++ci) {
for (int k = 0; k < kernel_size; ++k) {
for (int i = 0; i < n; i++) {
*(dout_c++) = *(din_array[i]++);
}
}
}
}
}
/* preprocessing inputs
* input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws]
* n = he - hs
*/
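/* In effect, per channel: rows i in [hs, he) that fall outside [0, height)
 * are read from zero_ptr, columns in [ws, 0) and [width, we) are zero-filled,
 * and the valid interior is memcpy'd, so the packed block already contains
 * the padding border and the compute kernels never test bounds. */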
template <typename dtype>
void slidingwindow_prepack_input(const dtype* din, dtype* dout, int cs, int ce,
int hs, int he, int ws, int we, int channel,
int width, int height, dtype* zero_ptr) {
int n = he - hs;
int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we;
int size_w = we - ws;
int size_wc_len = size_w * channel;
int size_c = width * height;
int valid_w = w1 - w0;
size_t valid_w_byte = valid_w * sizeof(dtype);
dtype* out_array[n];
out_array[0] = dout;
for (int i = 1; i < n; i++) {
out_array[i] = out_array[i - 1] + size_wc_len;
}
for (int c = 0; c < channel; ++c) {
int j = 0;
// valid height
for (int i = hs; i < he; i++) {
// get address
const dtype* in_array;
if (i < 0 || i >= height) {
in_array = zero_ptr;
} else {
in_array = din + i * width;
}
for (int w = ws; w < w0; ++w) {
*(out_array[j]++) = 0.f;
}
memcpy(out_array[j], in_array, valid_w_byte);
out_array[j] += valid_w;
for (int w = w1; w < we; ++w) {
*(out_array[j]++) = 0.f;
}
j++;
}
din += size_c;
}
}
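/* Broadcast one 4-wide bias vector over a c4-interleaved block. Note that
 * only whole groups of 4 floats are written; the callers always pass a size
 * that is a multiple of 4 (wout_round is rounded up to the 4-wide block). */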
inline void slidingwindow_fill_bias(float* dout, const float* bias, int size) {
float32x4_t vb = vld1q_f32(bias);
int cnt = size / 4;
for (int i = 0; i < cnt; ++i) {
vst1q_f32(dout, vb);
dout += 4;
}
}
void slidingwindow_fill_bias(float* dout, const float* bias, int ch_num,
int ch_size);
void slidingwindow_writeout_c1_fp32(const float* din, float* dout, int cs,
int ce, int hs, int he, int ws, int we,
int channel, int height, int width,
bool flag_relu, float* trash_ptr);
void slidingwindow_writeout_c4_fp32(const float* din, float* dout, int cs,
int ce, int hs, int he, int ws, int we,
int channel, int height, int width,
bool flag_relu, float* trash_ptr);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
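A minimal usage sketch of the new path, assuming header names not shown in
this patch ("operators/math/slidingwindow_conv3x3.h" here is illustrative):
repack the 3x3 filter once, then call the stride-2 kernel with {pad_h, pad_w}.

// Hypothetical driver, for illustration only; not part of the patch.
#include <vector>
#include "framework/tensor.h"
#include "operators/math/slidingwindow_conv3x3.h"  // assumed header name
#include "operators/math/slidingwindow_utils.h"

void RunConv3x3s2Example(const paddle_mobile::framework::Tensor &input,
                         const paddle_mobile::framework::Tensor &weight,
                         paddle_mobile::framework::Tensor *output) {
  using namespace paddle_mobile::operators::math;  // NOLINT
  // Repack [chout, chin, 3, 3] weights into 4-wide output-channel groups.
  paddle_mobile::framework::Tensor packed_weight;
  slidingwindow_transform_weight<float>(weight, &packed_weight);
  std::vector<int> paddings = {1, 1};  // {pad_h, pad_w}
  SlidingwindowConv3x3s2Faster<float, float>(&input, &packed_weight, paddings,
                                             output);
}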