diff --git a/src/common/types.h b/src/common/types.h index 48e532d810b08b64e4143104918f0741771c1d75..b59a8df2dfb76219d44e44a48cbee1a0935be9b0 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -111,6 +111,26 @@ enum PoolingType { FIRST = 3, }; +enum PowerMode { + PERFORMANCE_PRIORITY = 0, // let threads run on big cores if + // thread_num <= big_cores_num, + // otherwise the power mode will be + // set to AUTO and all threads are + // scheduled by system + EFFICIENCY_PRIORITY = 1, // let threads run on little cores if + // thread_num <= little_cores_num, + // otherwise the power mode will be + // set to AUTO and all threads are + // scheduled by system + PERFORMANCE_ONLY = 2, // force threads run on big cores, + // and the remains are ignored if + // exceed the number big cores + EFFICIENCY_ONLY = 3, // force threads run on little cores, + // and the remains are ignored if + // exceed the number of little cores + AUTO = 4, // scheduled by system +}; + struct PaddleMobileConfigInternal { bool load_when_predict = false; }; diff --git a/src/framework/context.cpp b/src/framework/context.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c7319ba02cd8e9516201bff4ea12da3224d7b06e --- /dev/null +++ b/src/framework/context.cpp @@ -0,0 +1,523 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "framework/context.h" +#include +#include +#include "common/log.h" + +#ifdef __APPLE__ +#include "TargetConditionals.h" +#ifdef TARGET_OS_IPHONE +// iOS +#elif TARGET_OS_MAC +// Mac OS +#else +// Unsupported platform +#endif +#include +#include +#include +#else // Linux or Android +#include +#include +#endif + +namespace paddle_mobile { +namespace framework { + +const int DEFAULT_L1_CACHE_SIZE = 32 * 1024; +const int DEFAULT_L2_CACHE_SIZE = 2048 * 1024; +const int DEFAULT_L3_CACHE_SIZE = 0; + +void fill_cpu_cache_size(std::vector *cpu_cache_sizes, int value, + const std::vector cpu_ids = {}) { + int num = cpu_ids.size(); + if (num > 0) { + for (int i = 0; i < num; i++) { + (*cpu_cache_sizes)[cpu_ids[i]] = value; + } + } else { + num = cpu_cache_sizes->size(); + for (int i = 0; i < num; i++) { + (*cpu_cache_sizes)[i] = value; + } + } +} + +int get_cpu_num() { +#ifdef __APPLE__ + int count = 0; + size_t len = sizeof(count); + sysctlbyname("hw.ncpu", &count, &len, NULL, 0); + if (count < 1) { + count = 1; + } + return count; +#else // Linux or Android + // get cpu num from /sys/devices/system/cpu/cpunum/uevent + int max_cpu_num = 20; + int count = 0; + for (int i = 0; i < max_cpu_num; i++) { + char path[256]; + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i); + FILE *fp = fopen(path, "rb"); + if (!fp) { + break; + } + count++; + fclose(fp); + } + if (count < 1) { + count = 1; + } + return count; +#endif +} + +#if !defined(__APPLE__) // Linux or Android +std::string get_cpu_name() { + FILE *fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return ""; + } + char line[1024]; + while (!feof(fp)) { + char *s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "Hardware") != NULL) { + fclose(fp); + return std::string(line); + } + } + fclose(fp); + return ""; +} + +int get_cpu_max_freq_khz(int cpu_id) { + // first try, for all possible cpu + char path[256]; +#ifdef __ANDROID__ + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpu_id); + FILE *fp = fopen(path, "rb"); + if (!fp) { + // second try, for online cpu + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", + cpu_id); + fp = fopen(path, "rb"); + if (!fp) { + // third try, for online cpu + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", + cpu_id); + fp = fopen(path, "rb"); + if (!fp) { + return 0; + } + int max_freq_khz = 0; + if (fscanf(fp, "%d", &max_freq_khz) <= 0) { + max_freq_khz = 0; + } + fclose(fp); + return max_freq_khz; + } + } + int max_freq_khz = 0; + while (!feof(fp)) { + int freq_khz = 0; + int nscan = fscanf(fp, "%d %*d", &freq_khz); + if (nscan != 1) { + break; + } + if (freq_khz > max_freq_khz) { + max_freq_khz = freq_khz; + } + } + fclose(fp); + return max_freq_khz; +#else + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/scaling_max_freq", cpu_id); + FILE *fp = fopen(path, "r"); + if (!fp) { + return 0; + } + int max_freq_khz = 0; + if (fscanf(fp, "%d", &max_freq_khz) <= 0) { + max_freq_khz = 0; + } + fclose(fp); + return max_freq_khz; +#endif +} + +void get_cpu_cache_size(int cpu_id, int *l1_cache_size, int *l2_cache_size, + int *l3_cache_size) { + int max_cache_idx_num = 10; + *l1_cache_size = DEFAULT_L1_CACHE_SIZE; + *l2_cache_size = DEFAULT_L2_CACHE_SIZE; + *l3_cache_size = DEFAULT_L3_CACHE_SIZE; + for (int i = 0; i < max_cache_idx_num; i++) { + char path[256]; + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cache/index%d/level", cpu_id, i); + FILE *fp = fopen(path, "rb"); + if (fp) { + int level = -1; + fscanf(fp, "%d", &level); + fclose(fp); + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cache/index%d/size", cpu_id, i); + fp = fopen(path, "rb"); + if (fp) { + int size = -1; + fscanf(fp, "%d", &size); + fclose(fp); + if (size >= 0) { + if (level == 1) { + *l1_cache_size = size * 1024; + } else if (level == 2) { + *l2_cache_size = size * 1024; + } else if (level == 3) { + *l3_cache_size = size * 1024; + } + } + } + } + } +} + +int check_online(std::vector *cpu_ids) { + if (cpu_ids->size() == 0) { + return 0; + } + std::vector online_cpu_ids; + char path[256]; + for (int i = 0; i < cpu_ids->size(); i++) { + int cpu_id = (*cpu_ids)[i]; + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online", + cpu_id); + FILE *fp = fopen(path, "rb"); + if (fp) { + int is_online = 0; + fscanf(fp, "%d", &is_online); + fclose(fp); + if (is_online != 0) { + online_cpu_ids.push_back(cpu_id); + } + } + // open failed(Permission denied) + } + *cpu_ids = online_cpu_ids; + return cpu_ids->size(); +} + +int set_sched_affinity(const std::vector &cpu_ids) { +// cpu_set_t definition +// ref http://stackoverflow.com/questions/16319725/android-set-thread-affinity +#define CPU_SETSIZE 1024 +#define __NCPUBITS (8 * sizeof(unsigned long)) + typedef struct { + unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; + } cpu_set_t; + +#define CPU_SET(cpu, cpusetp) \ + ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) + +#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t)) + + // set affinity for thread +#ifdef __GLIBC__ + pid_t pid = syscall(SYS_gettid); +#else + pid_t pid = gettid(); +#endif + cpu_set_t mask; + CPU_ZERO(&mask); + for (int i = 0; i < cpu_ids.size(); i++) { + CPU_SET(cpu_ids[i], &mask); + } + int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); + if (syscallret) { + LOG(kLOG_WARNING) << "invoke syscall(__NR_sched_setaffinity) error(ret=" + << syscallret << ")"; + return -1; + } + return 0; +} + +int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, + std::vector *little_core_ids, + std::vector *l1_cache_sizes, + std::vector *l2_cache_sizes, + std::vector *l3_cache_sizes, + std::string hardware_name) { + /* Snapdragon */ + if (hardware_name.find("SDM845") != std::string::npos) { // 845 + *cpu_num = 8; + *big_core_ids = {4, 5, 6, 7}; + *little_core_ids = {0, 1, 2, 3}; + l1_cache_sizes->resize(*cpu_num); + l2_cache_sizes->resize(*cpu_num); + l3_cache_sizes->resize(*cpu_num); + fill_cpu_cache_size(l1_cache_sizes, 64 * 1024); + fill_cpu_cache_size(l2_cache_sizes, 256 * 1024, *big_core_ids); + fill_cpu_cache_size(l2_cache_sizes, 128 * 1024, *little_core_ids); + fill_cpu_cache_size(l3_cache_sizes, 2048 * 1024); + return 0; + } else if (hardware_name.find("SDM710") != std::string::npos) { // 710 + *cpu_num = 8; + *big_core_ids = {6, 7}; + *little_core_ids = {0, 1, 2, 3, 4, 5}; + l1_cache_sizes->resize(*cpu_num); + l2_cache_sizes->resize(*cpu_num); + l3_cache_sizes->resize(*cpu_num); + fill_cpu_cache_size(l1_cache_sizes, 64 * 1024, *big_core_ids); + fill_cpu_cache_size(l1_cache_sizes, 32 * 1024, *little_core_ids); + fill_cpu_cache_size(l2_cache_sizes, 256 * 1024, *big_core_ids); + fill_cpu_cache_size(l2_cache_sizes, 128 * 1024, *little_core_ids); + fill_cpu_cache_size(l3_cache_sizes, 1024 * 1024); + return 0; + } else if (hardware_name.find("MSM8998") != std::string::npos) { // 835 + *cpu_num = 8; + *big_core_ids = {4, 5, 6, 7}; + *little_core_ids = {0, 1, 2, 3}; + l1_cache_sizes->resize(*cpu_num); + l2_cache_sizes->resize(*cpu_num); + l3_cache_sizes->resize(*cpu_num); + fill_cpu_cache_size(l1_cache_sizes, 64 * 1024, *big_core_ids); + fill_cpu_cache_size(l1_cache_sizes, 32 * 1024, *little_core_ids); + // real L2 cache size is 2M, while that will get bad performace on conv3x3s1 + // or gemm, set to 1M or 512K + // fill_cpu_cache_size(l2_cache_sizes, 2048 *1024, + // *big_core_ids); + // fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024, + // *little_core_ids); + fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024); + fill_cpu_cache_size(l3_cache_sizes, 0); + return 0; + } else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653 + *cpu_num = 8; + *big_core_ids = {0, 1, 2, 3, 4, 5, 6, 7}; + *little_core_ids = {}; + l1_cache_sizes->resize(*cpu_num); + l2_cache_sizes->resize(*cpu_num); + l3_cache_sizes->resize(*cpu_num); + fill_cpu_cache_size(l1_cache_sizes, 32 * 1024); + fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024); + fill_cpu_cache_size(l3_cache_sizes, 0); + return 0; + } + return -1; +} + +// divide cpu cores into big and little clusters by max frequency +void get_cpu_info_by_probe(int cpu_num, std::vector *big_core_ids, + std::vector *little_core_ids, + std::vector *l1_cache_sizes, + std::vector *l2_cache_sizes, + std::vector *l3_cache_sizes) { + // get maxium & minium of cpu_max_freqs + std::vector cpu_max_freqs(cpu_num); + for (int i = 0; i < cpu_num; i++) { + cpu_max_freqs[i] = get_cpu_max_freq_khz(i) / 1000; + } + int max_cpu_max_freq = cpu_max_freqs[0]; + int min_cpu_max_freq = cpu_max_freqs[0]; + for (int i = 1; i < cpu_num; i++) { + int cur_cpu_max_freq = cpu_max_freqs[i]; + if (cur_cpu_max_freq < min_cpu_max_freq) { + min_cpu_max_freq = cur_cpu_max_freq; + } else if (cur_cpu_max_freq > max_cpu_max_freq) { + max_cpu_max_freq = cur_cpu_max_freq; + } + } + int mid_max_freq_khz = (max_cpu_max_freq + min_cpu_max_freq) / 2; + big_core_ids->clear(); + little_core_ids->clear(); + for (int i = 0; i < cpu_num; i++) { + if (cpu_max_freqs[i] >= mid_max_freq_khz) { + big_core_ids->push_back(i); + } else { + little_core_ids->push_back(i); + } + } + /* get l1, l2, l3 cache size for each core */ + l1_cache_sizes->resize(cpu_num); + l2_cache_sizes->resize(cpu_num); + l3_cache_sizes->resize(cpu_num); + for (int i = 0; i < cpu_num; i++) { + get_cpu_cache_size(i, &((*l1_cache_sizes)[i]), &((*l2_cache_sizes)[i]), + &((*l3_cache_sizes)[i])); + } +} + +void bind_threads(const std::vector &cpu_ids) { +#ifdef _OPENMP + int num_threads = omp_get_max_threads(); + std::vector ssarets; + for (int i = 0; i < num_threads; i++) { + ssarets.push_back(0); + } +#pragma omp parallel for + for (int i = 0; i < num_threads; i++) { + ssarets[i] = set_sched_affinity(cpu_ids); + } + for (int i = 0; i < num_threads; i++) { + if (ssarets[i] != 0) { + LOG(kLOG_WARNING) << "set cpu affinity failed, thread idx: " << i; + return; + } + } +#else + int ssaret = set_sched_affinity(cpu_ids); + if (ssaret != 0) { + LOG(kLOG_WARNING) << "set cpu affinity failed, thread idx: 0 "; + return; + } +#endif +} +#endif + +CPUContext::CPUContext() { + _cpu_num = get_cpu_num(); + _big_core_ids.clear(); + _little_core_ids.clear(); +#ifdef __APPLE__ + // set default L1, L2 and L3 cache sizes + _l1_cache_sizes.resize(_cpu_num); + _l2_cache_sizes.resize(_cpu_num); + _l3_cache_sizes.resize(_cpu_num); + fill_cpu_cache_size(&_l1_cache_sizes, DEFAULT_L1_CACHE_SIZE); + fill_cpu_cache_size(&_l2_cache_sizes, DEFAULT_L2_CACHE_SIZE); + fill_cpu_cache_size(&_l3_cache_sizes, DEFAULT_L3_CACHE_SIZE); +#else // Linux or Android + // probe cpu info, and set big&litte clusters, L1, L2 and L3 cache sizes + std::string cpu_name = get_cpu_name(); + bool failed = + get_cpu_info_by_name(&_cpu_num, &_big_core_ids, &_little_core_ids, + &_l1_cache_sizes, &_l2_cache_sizes, &_l3_cache_sizes, + cpu_name) != 0; + if (failed) { + get_cpu_info_by_probe(_cpu_num, &_big_core_ids, &_little_core_ids, + &_l1_cache_sizes, &_l2_cache_sizes, &_l3_cache_sizes); + } + LOG(kLOG_INFO) << "CPU num: " << _cpu_num; + for (int i = 0; i < _cpu_num; i++) { + LOG(kLOG_INFO) << i << " L1 Cache: " << _l1_cache_sizes[i] << "KB" + << " L2 Cache: " << _l2_cache_sizes[i] << "KB" + << " L3 Cache: " << _l3_cache_sizes[i] << "KB"; + } + LOG(kLOG_INFO) << "Big cores: "; + for (int i = 0; i < _big_core_ids.size(); i++) { + LOG(kLOG_INFO) << _big_core_ids[i]; + } + LOG(kLOG_INFO) << "Little cores: "; + for (int i = 0; i < _little_core_ids.size(); i++) { + LOG(kLOG_INFO) << _little_core_ids[i]; + } +#endif + // use single thread by default + set_thread_num(1, PERFORMANCE_PRIORITY); +} + +void CPUContext::set_thread_num(int thread_num, PowerMode power_mode) { + int big_core_num = _big_core_ids.size(); + int little_core_num = _little_core_ids.size(); +#ifdef _OPENMP + if (thread_num > _cpu_num) { + thread_num = _cpu_num; + } +#else + thread_num = 1; +#endif + std::vector bind_core_ids; + if (power_mode == PERFORMANCE_PRIORITY || power_mode == PERFORMANCE_ONLY) { + if (big_core_num > 0) { + bind_core_ids = _big_core_ids; + if (power_mode == PERFORMANCE_ONLY && thread_num > big_core_num) { + LOG(kLOG_ERROR) << "thread_num(" << thread_num + << ") exceed the big cores num (" << big_core_num << ")" + << ", force to set thread_num = " << big_core_num; + thread_num = big_core_num; + } + } + } else if (power_mode == EFFICIENCY_PRIORITY || + power_mode == EFFICIENCY_ONLY) { + if (little_core_num > 0) { + bind_core_ids = _little_core_ids; + if (power_mode == EFFICIENCY_ONLY && thread_num > little_core_num) { + LOG(kLOG_ERROR) << "thread_num(" << thread_num + << ") exceed the little cores num (" << little_core_num + << ")" + << ", force to set thread_num = " << little_core_num; + thread_num = little_core_num; + } + } + } + _power_mode = AUTO; +#ifdef _OPENMP + omp_set_num_threads(thread_num); + thread_num = omp_get_max_threads(); +#endif +#if !defined(__APPLE__) // Linux or Android + if (bind_core_ids.size() > 0 && check_online(&bind_core_ids) >= thread_num) { + bind_threads(bind_core_ids); + _power_mode = power_mode; + } +#endif + LOG(kLOG_INFO) << "thread num: " << thread_num + << " power mode: " << _power_mode; +} + +int CPUContext::get_thread_num() { + int thread_num = 1; +#ifdef _OPENMP + thread_num = omp_get_max_threads(); +#endif + return thread_num; +} + +int CPUContext::get_cache_size(int level) { + std::vector *ptr = nullptr; + if (level == 1) { + ptr = &_l1_cache_sizes; + } else if (level == 2) { + ptr = &_l2_cache_sizes; + } else if (level == 3) { + ptr = &_l3_cache_sizes; + } else { + return 0; + } + if (_power_mode == PERFORMANCE_PRIORITY || _power_mode == PERFORMANCE_ONLY) { + return (*ptr)[_big_core_ids[0]]; + } else if (_power_mode == EFFICIENCY_PRIORITY || + _power_mode == EFFICIENCY_ONLY) { + return (*ptr)[_little_core_ids[0]]; + } else { // AUTO + return (*ptr)[0]; + } +} + +void *CPUContext::get_work_space(int size_in_byte) { + return reinterpret_cast( + _workspace.mutable_data(make_ddim({size_in_byte}))); +} + +} // namespace framework +} // namespace paddle_mobile diff --git a/src/framework/context.h b/src/framework/context.h index 0f1d9bb7ada7e42766360735aeb260f076f5b6b7..4efab6c3a0427e18ee404b7a0ac1d158e26aaa7f 100644 --- a/src/framework/context.h +++ b/src/framework/context.h @@ -18,63 +18,45 @@ limitations under the License. */ #include #endif -#define MOBILE_MAX_CPU_NUM 8 +#include +#include "framework/tensor.h" namespace paddle_mobile { namespace framework { struct CPUContext { private: - CPUContext() : num_cpus(4), num_threads(1) { - // TODO(hjchen2) - for (int i = 0; i < num_cpus; ++i) { - cpu_frequencies[i] = 2400; // 2400 MHz - max_cpu_frequencies[i] = 2400; // 2400 MHz - } - // L1_cache = 32000; // 32K - L1_cache = 32 * 1024; - L2_cache = 2000000; // 2M - // L2_cache = 512000; - } - - public: - void set_num_threads(int threads) { -#if _ONENMP - omp_set_num_threads(threads); - if (threads <= omp_get_max_threads()) { - num_threads = threads; - } else { - num_threads = omp_get_max_threads(); - } -#endif - num_threads = (num_threads > 1) ? num_threads : 1; - } - + CPUContext(); virtual ~CPUContext() {} public: static CPUContext* Context() { - static CPUContext* ctx = new CPUContext; + static CPUContext* ctx = nullptr; + if (ctx == nullptr) { + ctx = new CPUContext(); + } return ctx; } - int num_cpus; - int num_threads; - int cpu_frequencies[MOBILE_MAX_CPU_NUM]; - int max_cpu_frequencies[MOBILE_MAX_CPU_NUM]; - - int L1_cache; - int L2_cache; + void set_thread_num(int thread_num, + PowerMode power_mode = PERFORMANCE_PRIORITY); + int get_thread_num(); + PowerMode get_power_mode() const { return _power_mode; } + int get_cache_size(int level); + int get_l1_cache_size() { return get_cache_size(1); } + int get_l2_cache_size() { return get_cache_size(2); } + int get_l3_cache_size() { return get_cache_size(3); } + void* get_work_space(int size_in_byte); + + int _cpu_num; + PowerMode _power_mode; + std::vector _big_core_ids; + std::vector _little_core_ids; + std::vector _l1_cache_sizes; + std::vector _l2_cache_sizes; + std::vector _l3_cache_sizes; + Tensor _workspace; }; -inline void set_global_num_threads(int threads) { - // CPUContext::Context()->set_num_threads(threads); - CPUContext::Context()->num_threads = threads; -} - -inline int get_global_num_threads() { - return CPUContext::Context()->num_threads; -} - } // namespace framework } // namespace paddle_mobile diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 2ac7035a70916c121ac65be5f280c2ae58326196..669ad42469fc9ca4e00a6d8ae11fe3d53b433ca9 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -40,8 +40,8 @@ namespace framework { #pragma mark - executor template -void Executor::SetThreadNum(int threads) { - set_global_num_threads(threads); +void Executor::SetThreadNum(int thread_num, PowerMode power_mode) { + CPUContext::Context()->set_thread_num(thread_num, power_mode); } template @@ -440,7 +440,7 @@ std::shared_ptr Executor::GetOutput( template PMStatus Executor::Predict() { #if _OPENMP - omp_set_num_threads(get_global_num_threads()); + omp_set_num_threads(CPUContext::Context()->get_thread_num()); #endif // clear all no persistable tensor array since write_to_array // is always push back a new tensor in the array diff --git a/src/framework/executor.h b/src/framework/executor.h index 57b05dbc6c5066ef61c7f0321706ed27b3051e2c..c2d096182d1a94317c4909a7a468f04148b79695 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include "common/types.h" @@ -37,7 +38,8 @@ class Executor { paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1, const bool use_optimize = true, const bool lod_mode = false); - void SetThreadNum(int threads); + void SetThreadNum(int thread_num, + PowerMode power_mode = PERFORMANCE_PRIORITY); PMStatus Predict(const std::vector> &inputs); PMStatus Predict( diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index bf9749393b154f5a1484a95852c2bad300037344..a961622a90eb6f8a43780ba2f147fcb718afda88 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -29,8 +29,9 @@ limitations under the License. */ namespace paddle_mobile { template -void PaddleMobile::SetThreadNum(int num) { - executor_->SetThreadNum(num); +void PaddleMobile::SetThreadNum(int thread_num, + PowerMode power_mode) { + executor_->SetThreadNum(thread_num, power_mode); } template diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index b05485fcae954e2aa2540ba81110fe36e6421019..2203c9cb5a1002823c998e5604b8c89908a0aae6 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -83,7 +83,8 @@ class PaddleMobile { bool quantification = false, int batch_size = 1, bool lod_mode = false); - void SetThreadNum(int count); + void SetThreadNum(int thread_num, + PowerMode power_mode = PERFORMANCE_PRIORITY); void Clear(); double GetPredictTime(); diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 7ae525be7efe1b23325e55c624a7db28506257fa..e403f51357e2c46bf1e59be92c54778a5abfa595 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/math/slidingwindow_utils.h" #include "operators/math/winograd/winograd_transform.h" namespace paddle_mobile { @@ -56,38 +57,31 @@ void InitBaseConvKernel(ConvParam *param) { } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && - param->Strides()[0] == 1 && param->Dilations()[0] == 1 -#if 1 - && (param->Input()->dims()[1] >= 8 && - param->Output()->dims()[1] >= 8) -#endif - ) { - param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; + param->Strides()[0] == 1 && param->Dilations()[0] == 1) { // transform weight Variable *transformed_var = param->GetScope()->Var(); param->transformed_filter_ = transformed_var->GetMutable(); - operators::math::winograd_transform_weight<8, 3>( - *param->Filter(), param->transformed_filter_); + if (param->Input()->dims()[1] >= 32 && param->Output()->dims()[1] >= 32 && + param->Output()->dims()[2] > 16 && param->Output()->dims()[3] > 16) { + math::winograd_transform_weight<8, 3>(*param->Filter(), + param->transformed_filter_); + param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; + } else { + math::slidingwindow_transform_weight(*param->Filter(), + param->transformed_filter_); + param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT; + } } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && - param->Strides()[0] == 1 && param->Dilations()[0] == 1 -#if 1 - && (param->Input()->dims()[2] >= 48 && - param->Output()->dims()[1] <= 24) -#endif - ) { - param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT; - } else if (conv3x3 && param->Groups() == 1 && - param->Strides()[0] == param->Strides()[1] && - param->Dilations()[0] == param->Dilations()[1] && - param->Strides()[0] == 2 && param->Dilations()[0] == 1 -#if 1 - && (param->Input()->dims()[2] >= 48 && - param->Output()->dims()[1] <= 24) -#endif - ) { + param->Strides()[0] == 2 && param->Dilations()[0] == 1) { + // transform weight + Variable *transformed_var = param->GetScope()->Var(); + param->transformed_filter_ = + transformed_var->GetMutable(); + math::slidingwindow_transform_weight(*param->Filter(), + param->transformed_filter_); param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S2_FLOAT; } else { param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.cpp b/src/operators/kernel/central-arm-func/conv_arm_func.cpp index dd41df59f303dfce1a6b9eb598f6dd34d6b014d7..fbdf52df911a57534c5dbf297b351c6b3dc59e87 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.cpp +++ b/src/operators/kernel/central-arm-func/conv_arm_func.cpp @@ -243,9 +243,15 @@ void SlidingwindowConv3x3(const ConvParam ¶m) { output->mutable_data(); if (strides[0] == 1) { - math::SlidingwindowConv3x3s1(input, filter, paddings, output); + // math::SlidingwindowConv3x3s1(input, filter, paddings, + // output); + math::SlidingwindowConv3x3s1Faster( + input, param.transformed_filter_, paddings, output); } else if (strides[0] == 2) { - math::SlidingwindowConv3x3s2(input, filter, paddings, output); + // math::SlidingwindowConv3x3s2(input, filter, paddings, + // output); + math::SlidingwindowConv3x3s2Faster( + input, param.transformed_filter_, paddings, output); } else { GemmConv(param); } diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h index 1a536cba4e7ce6a52ba409856af9151152bf87eb..976415b9ac1e3d0761ae11588c27ee9b99156d1f 100644 --- a/src/operators/math/gemm/executor.h +++ b/src/operators/math/gemm/executor.h @@ -29,8 +29,6 @@ namespace paddle_mobile { namespace operators { namespace math { -static framework::CPUContext *g_cpu_ctx = framework::CPUContext::Context(); - int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; } unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num, const int N, const int K) { @@ -70,11 +68,15 @@ class GemmExecutor : public Executor { unsigned int L1_size = 0; unsigned int L2_size = 0; if (M_ > N_) { - L2_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, M_, K_); - L1_size = g_cpu_ctx->L2_cache; + L2_size = + ResetL1Cache(framework::CPUContext::Context()->get_l1_cache_size(), + num_threads_, M_, K_); + L1_size = framework::CPUContext::Context()->get_l2_cache_size(); } else { - L1_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, N_, K_); - L2_size = g_cpu_ctx->L2_cache; + L1_size = + ResetL1Cache(framework::CPUContext::Context()->get_l1_cache_size(), + num_threads_, N_, K_); + L2_size = framework::CPUContext::Context()->get_l2_cache_size(); } rhs_tile_num_ = L1_size / (K_ * sizeof(Itype)); diff --git a/src/operators/math/slidingwindow_conv3x3.cpp b/src/operators/math/slidingwindow_conv3x3.cpp index 76a79c07740d435e545121378a8c5739c76517c6..0452a290275d72acc80e193c41c3bb0e3ffc5ff0 100644 --- a/src/operators/math/slidingwindow_conv3x3.cpp +++ b/src/operators/math/slidingwindow_conv3x3.cpp @@ -14,6 +14,8 @@ limitations under the License. */ #include "operators/math/slidingwindow_conv3x3.h" #include +#include "framework/context.h" +#include "operators/math/slidingwindow_utils.h" #if __ARM_NEON #include #endif @@ -703,7 +705,7 @@ void SlidingwindowConv3x3s1(const framework::Tensor *input, in_ptr3--; in_ptr4--; } -#endif //__aarch64__ +#endif // __aarch64__ #endif // __ARM_NEON // remain output_width @@ -1250,7 +1252,7 @@ void SlidingwindowConv3x3s1(const framework::Tensor *input, } } -#endif //__aarch64__ +#endif // __aarch64__ #endif // __ARM_NEON // remain output_width @@ -1738,7 +1740,7 @@ void SlidingwindowConv3x3s1(const framework::Tensor *input, in_ptr3--; in_ptr4--; } -#endif //__aarch64__ +#endif // __aarch64__ #endif // __ARM_NEON // remain output_width @@ -2940,7 +2942,7 @@ void SlidingwindowConv3x3s2(const framework::Tensor *input, in_ptr3 = in_ptr2 + input_w; } } -#endif //__aarch64__ +#endif // __aarch64__ #endif // __ARM_NEON // remain output_width @@ -3594,7 +3596,7 @@ void SlidingwindowConv3x3s2(const framework::Tensor *input, "q7", "q8", "q10", "q12", "q13", "q14", "q15"); } } -#endif //__aarch64__ +#endif // __aarch64__ #endif // __ARM_NEON out_ptr1 -= 4; out_ptr1 += 4; @@ -3705,6 +3707,1956 @@ void SlidingwindowConv3x3s2(const framework::Tensor *input, } } +template <> +void SlidingwindowConv3x3s1Faster( + const framework::Tensor *input, framework::Tensor *filter, + const std::vector &paddings, framework::Tensor *output) { + const float *din = input->data(); + float *dout = output->mutable_data(); + const float *weights = filter->mutable_data(); + const float *bias = nullptr; + bool relu = false; + const int num = input->dims()[0]; + const int chin = input->dims()[1]; + const int hin = input->dims()[2]; + const int win = input->dims()[3]; + const int chout = output->dims()[1]; + const int hout = output->dims()[2]; + const int wout = output->dims()[3]; + const int pad_h = paddings[0]; + const int pad_w = paddings[1]; + const int threads = framework::CPUContext::Context()->get_thread_num(); + int l2_size = + framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float); + + const int hout_c_block = 4; + const int hout_r_kernel = 2; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 2; + + int hout_r_block = (l2_size - 2 * win_round * chin) / + (win_round * chin + hout_c_block * wout_round * threads); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 2; + + float ptr_zero[win_round]; + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; + + int in_len = win_round * chin; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + float *pre_din = + static_cast(framework::CPUContext::Context()->get_work_space( + (pre_in_size + threads * pre_out_size) * sizeof(float))); + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = chin * 9; // kernel_w * kernel_h; + int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int c_remain = chout - (chout / hout_c_block) * hout_c_block; + int c_round_down = (chout / hout_c_block) * hout_c_block; + + int out_row_stride = hout_c_block * wout_round; + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * chin * size_in_channel; + float *dout_batch = dout + n * chout * size_out_channel; + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - pad_h; + int he = hs + h_kernel + 2; + slidingwindow_prepack_input(din_batch, pre_din, 0, chin, hs, he, ws, we, + chin, win, hin, ptr_zero); +#pragma omp parallel for + for (int c = 0; c < chout - (hout_c_block - 1); c += hout_c_block) { +#ifdef _OPENMP + float *pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float *pre_out = pre_din + pre_in_size; +#endif + const float *block_inr0 = pre_din; + const float *block_inr1 = block_inr0 + in_len; + const float *block_inr2 = block_inr1 + in_len; + const float *block_inr3 = block_inr2 + in_len; + + const float *weight_c = weights + c * w_stride; + const float *bias_ptr = ptr_zero; + if (bias != nullptr) { + bias_ptr = bias + c; + } + slidingwindow_fill_bias(pre_out, bias_ptr, + wout_round * hout_c_block * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float *wc0 = weight_c; + + const float *inr0 = block_inr0; + const float *inr1 = block_inr1; + const float *inr2 = block_inr2; + const float *inr3 = block_inr3; + + float *pre_out0 = pre_out + hk * out_row_stride; + float *pre_out1 = pre_out0 + out_row_stride; +#ifdef __aarch64__ + for (int i = 0; i < chin; ++i) { + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + float32x4_t w0 = vld1q_f32(wc0); // w0, v23 + float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + + int cnt = w_loop; + asm volatile( + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "2: \n" /* main loop*/ + /* r0, r1, mul w0, get out r0, r1 */ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ + "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ + "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ + "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ + "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ + "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ + "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ + + /* r0, r1, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ + "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ + "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ + "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ + "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ + "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ + "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ + + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + + /* r0, r1, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ + "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ + "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ + "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ + "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ + "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ + "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ + + /* r1, r2, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ + "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ + "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ + "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ + "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ + "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ + "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ + + "ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/ + + /* r1, r2, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ + "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ + "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ + "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ + "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ + "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ + "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ + + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + + /* r1, r2, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ + "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ + "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ + "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ + "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ + "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ + "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ + + /* r2, r3, mul w6, get out r0, r1 */ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ + "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ + "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ + "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ + "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ + "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ + "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ + + "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ + + /* r2, r3, mul w7, get out r0, r1 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ + "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ + "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ + "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ + "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ + "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ + "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + + /* r2, r3, mul w8, get out r0, r1 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ + "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ + "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ + + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3), + [w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7), + [w8] "w"(w8) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < chin; ++i) { + const float *wc0 = weight_c + i * w_stride_chin; + + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + + int cnt = w_loop; + asm volatile( + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " + "outr0, w0, w1, c0~c3\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + /* load r0, r1 */ + "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " + "4 float\n" + "vld1.32 {d2}, [%[r0]] @ load r0, " + "2 float\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + /* main loop */ + "0: @ main " + "loop\n" + /* mul r0 with w0, w1, w2, get out r0 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " + "outr1, w0, w1, c0~c3\n" + "vmla.f32 q8, q5, d0[0] @ w0 * " + "inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " + "outr1, w2, w3, c0~c3\n" + "vmla.f32 q9, q5, d0[1] @ w0 * " + "inr01\n" + "vmla.f32 q10, q5, d1[0] @ w0 * " + "inr02\n" + "vmla.f32 q11, q5, d1[1] @ w0 * " + "inr03\n" + "vld1.32 {d3-d4}, [%[r1]]! @ load r1, " + "4 float\n" + "vmla.f32 q8, q6, d0[1] @ w1 * " + "inr01\n" + "vmla.f32 q9, q6, d1[0] @ w1 * " + "inr02\n" + "vmla.f32 q10, q6, d1[1] @ w1 * " + "inr03\n" + "vmla.f32 q11, q6, d2[0] @ w1 * " + "inr04\n" + "vld1.32 {d5}, [%[r1]] @ load r0, " + "2 float\n" + "vmla.f32 q8, q7, d1[0] @ w2 * " + "inr02\n" + "vmla.f32 q9, q7, d1[1] @ w2 * " + "inr03\n" + "vmla.f32 q10, q7, d2[0] @ w2 * " + "inr04\n" + "vmla.f32 q11, q7, d2[1] @ w2 * " + "inr05\n" + + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " + "- 32, to start address\n" + + /* mul r1 with w0, w1, w2, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w0 * " + "inr10\n" + "vmla.f32 q13, q5, d3[1] @ w0 * " + "inr11\n" + "vmla.f32 q14, q5, d4[0] @ w0 * " + "inr12\n" + "vmla.f32 q15, q5, d4[1] @ w0 * " + "inr13\n" + "vmla.f32 q12, q6, d3[1] @ w1 * " + "inr11\n" + "vmla.f32 q13, q6, d4[0] @ w1 * " + "inr12\n" + "vmla.f32 q14, q6, d4[1] @ w1 * " + "inr13\n" + "vmla.f32 q15, q6, d5[0] @ w1 * " + "inr14\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " + "w4, to q5, q6\n" + "vmla.f32 q12, q7, d4[0] @ w2 * " + "inr12\n" + "vmla.f32 q13, q7, d4[1] @ w2 * " + "inr13\n" + "vmla.f32 q14, q7, d5[0] @ w2 * " + "inr14\n" + "vmla.f32 q15, q7, d5[1] @ w2 * " + "inr15\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " + "to q7\n" + + /* mul r1 with w3, w4, w5, get out r0 */ + "vmla.f32 q8, q5, d3[0] @ w3 * " + "inr10\n" + "vmla.f32 q9, q5, d3[1] @ w3 * " + "inr11\n" + "vmla.f32 q10, q5, d4[0] @ w3 * " + "inr12\n" + "vmla.f32 q11, q5, d4[1] @ w3 * " + "inr13\n" + "vld1.32 {d0-d1}, [%[r2]]! @ load r2, " + "4 float\n" + "vmla.f32 q8, q6, d3[1] @ w4 * " + "inr11\n" + "vmla.f32 q9, q6, d4[0] @ w4 * " + "inr12\n" + "vmla.f32 q10, q6, d4[1] @ w4 * " + "inr13\n" + "vmla.f32 q11, q6, d5[0] @ w4 * " + "inr14\n" + "vld1.32 {d2}, [%[r2]] @ load r2, " + "2 float\n" + "vmla.f32 q8, q7, d4[0] @ w5 * " + "inr12\n" + "vmla.f32 q9, q7, d4[1] @ w5 * " + "inr13\n" + "vmla.f32 q10, q7, d5[0] @ w5 * " + "inr14\n" + "vmla.f32 q11, q7, d5[1] @ w5 * " + "inr15\n" + + /* mul r2 with w3, w4, w5, get out r1 */ + "vmla.f32 q12, q5, d0[0] @ w3 * " + "inr20\n" + "vmla.f32 q13, q5, d0[1] @ w3 * " + "inr21\n" + "vmla.f32 q14, q5, d1[0] @ w3 * " + "inr22\n" + "vmla.f32 q15, q5, d1[1] @ w3 * " + "inr23\n" + "vmla.f32 q12, q6, d0[1] @ w4 * " + "inr21\n" + "vmla.f32 q13, q6, d1[0] @ w4 * " + "inr22\n" + "vmla.f32 q14, q6, d1[1] @ w4 * " + "inr23\n" + "vmla.f32 q15, q6, d2[0] @ w4 * " + "inr24\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " + "w7, to q5, q6\n" + "vmla.f32 q12, q7, d1[0] @ w5 * " + "inr22\n" + "vmla.f32 q13, q7, d1[1] @ w5 * " + "inr23\n" + "vmla.f32 q14, q7, d2[0] @ w5 * " + "inr24\n" + "vmla.f32 q15, q7, d2[1] @ w5 * " + "inr25\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " + "to q7\n" + + "sub %[wc0], %[wc0], #144 @ wc0 - " + "144 to start address\n" + + /* mul r2 with w6, w7, w8, get out r0 */ + "vmla.f32 q8, q5, d0[0] @ w6 * " + "inr20\n" + "vmla.f32 q9, q5, d0[1] @ w6 * " + "inr21\n" + "vld1.32 {d3-d4}, [%[r3]]! @ load r3, " + "4 float\n" + "vmla.f32 q10, q5, d1[0] @ w6 * " + "inr22\n" + "vmla.f32 q11, q5, d1[1] @ w6 * " + "inr23\n" + "vmla.f32 q8, q6, d0[1] @ w7 * " + "inr21\n" + "vmla.f32 q9, q6, d1[0] @ w7 * " + "inr22\n" + "vld1.32 {d5}, [%[r3]] @ load r3, " + "2 float\n" + "vmla.f32 q10, q6, d1[1] @ w7 * " + "inr23\n" + "vmla.f32 q11, q6, d2[0] @ w7 * " + "inr24\n" + "vmla.f32 q8, q7, d1[0] @ w8 * " + "inr22\n" + "vmla.f32 q9, q7, d1[1] @ w8 * " + "inr23\n" + "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " + "4 float\n" + "vmla.f32 q10, q7, d2[0] @ w8 * " + "inr24\n" + "vmla.f32 q11, q7, d2[1] @ w8 * " + "inr25\n" + "vld1.32 {d2}, [%[r0]] @ load r0, " + "2 float\n" + + /* mul r3 with w6, w7, w8, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w6 * " + "inr20\n" + "vmla.f32 q13, q5, d3[1] @ w6 * " + "inr21\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " + "r00, r01, c0~c3\n" + "vmla.f32 q14, q5, d4[0] @ w6 * " + "inr22\n" + "vmla.f32 q15, q5, d4[1] @ w6 * " + "inr23\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save " + "r02, r03, c0~c3\n" + "vmla.f32 q12, q6, d3[1] @ w7 * " + "inr21\n" + "vmla.f32 q13, q6, d4[0] @ w7 * " + "inr22\n" + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " + "outr0, w0, w1, c0~c3\n" + "vmla.f32 q14, q6, d4[1] @ w7 * " + "inr23\n" + "vmla.f32 q15, q6, d5[0] @ w7 * " + "inr24\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vmla.f32 q12, q7, d4[0] @ w8 * " + "inr22\n" + "vmla.f32 q13, q7, d4[1] @ w8 * " + "inr23\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + "vmla.f32 q14, q7, d5[0] @ w8 * " + "inr24\n" + "vmla.f32 q15, q7, d5[1] @ w8 * " + "inr25\n" + + "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " + "r10, r11, c0~c3\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " + "r12, r13, c0~c3\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + "subs %[cnt], #1 @ loop " + "count--\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), [wc0] "+r"(wc0) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + } + slidingwindow_writeout_c4_fp32(pre_out, dout_batch, c, c + hout_c_block, + h, h + h_kernel, 0, wout_round, chout, + hout, wout, relu, ptr_write); + } + const float *weight_remain_ptr = weights + c_round_down * w_stride; +#pragma omp parallel for + for (int c = 0; c < c_remain; ++c) { +#ifdef USE_OPENMP + float *pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float *pre_out = pre_din + pre_in_size; +#endif + + int c_idx = c_round_down + c; + + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + + const float *block_inr0 = pre_din; + const float *block_inr1 = block_inr0 + in_len; + const float *block_inr2 = block_inr1 + in_len; + const float *block_inr3 = block_inr2 + in_len; + + const float *bias_ptr = ptr_zero; + if (bias != nullptr) { + bias_ptr = bias + c_idx; + } + slidingwindow_fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float *wc0 = weight_remain_ptr; + + const float *inr0 = block_inr0; + const float *inr1 = block_inr1; + const float *inr2 = block_inr2; + const float *inr3 = block_inr3; + + float *pre_out0 = pre_out + hk * wout_round; + float *pre_out1 = pre_out0 + wout_round; +#ifdef __aarch64__ + for (int i = 0; i < chin; ++i) { + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 + float32x4_t w1 = vdupq_n_f32(wc0[4 + c]); // w1, v24 + float32x4_t w2 = vdupq_n_f32(wc0[8 + c]); // w2, v25 + float32x4_t w3 = vdupq_n_f32(wc0[12 + c]); // w3, v26 + float32x4_t w4 = vdupq_n_f32(wc0[16 + c]); // w4, v27 + float32x4_t w5 = vdupq_n_f32(wc0[20 + c]); // w5, v28 + float32x4_t w6 = vdupq_n_f32(wc0[24 + c]); // w6, v29 + float32x4_t w7 = vdupq_n_f32(wc0[28 + c]); // w7, v30 + float32x4_t w8 = vdupq_n_f32(wc0[32 + c]); // w8, v31 + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + + int cnt = w_loop; + asm volatile( + "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + "2: \n" /* main loop*/ + + "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ + "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ + + "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ + "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ + "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ + "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ + + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + + "fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/ + "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ + + "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ + "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ + + "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ + "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ + + "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ + "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ + "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ + "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ + + "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ + "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ + + "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ + "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ + + "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ + + "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ + "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ + + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + + "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ + "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ + + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + + "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ + "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ + + "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ + "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + + "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3), + [w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7), + [w8] "w"(w8) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v21", "v22"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < chin; ++i) { + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + //! get valid weights of current output channel + float w_tmp[10] = { + wc0[c], wc0[c + 4], wc0[c + 8], wc0[c + 12], wc0[c + 16], + wc0[c + 20], wc0[c + 24], wc0[c + 28], wc0[c + 32], 0.f}; + float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 + float32x4_t w1 = vld1q_f32(w_tmp + 3); // w3, w4, w5, q1 + float32x4_t w2 = vld1q_f32(w_tmp + 6); // w6, w7, w8, q2 + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + int cnt = w_loop / 2; + if (cnt > 0) { + asm volatile( + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " + "or01\n" + "vld1.32 {d6-d9}, [%[r0]]! @ load r0, 8 " + "float\n" + "vld1.32 {d10}, [%[r0]] @ load r0, 2 " + "float\n" + /* main loop */ + "0: @ main loop\n" + /* r0 * w0, w1, w2, get out r0*/ + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " + "or11\n" + "vext.32 q8, q3, q4, #1 @ r0, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r0, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, " + "0, 1, 2, 3\n" + "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, " + "4, 5, 6, 7\n" + "vext.32 q10, q3, q4, #2 @ r0, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r0, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " + "1, 2, 3, 4\n" + "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, " + "5, 6, 7, 8\n" + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 " + "float\n" + "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, " + "2, 3, 4, 5\n" + "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, " + "6, 7, 8, 9\n" + "vld1.32 {d10}, [%[r1]] @ load r1, 2 " + "float\n" + + /* r1 * w3, w4, w5, get out r0*/ + /* r1 * w0, w1, w2, get out r1*/ + "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, " + "0, 1, 2, 3\n" + "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, " + "4, 5, 6, 7\n" + "vext.32 q8, q3, q4, #1 @ r1, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r1, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, " + "0, 1, 2, 3\n" + "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, " + "4, 5, 6, 7\n" + "vext.32 q10, q3, q4, #2 @ r1, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r1, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " + "1, 2, 3, 4\n" + "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, " + "5, 6, 7, 8\n" + "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, " + "1, 2, 3, 4\n" + "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, " + "5, 6, 7, 8\n" + "vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 " + "float\n" + "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, " + "2, 3, 4, 5\n" + "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, " + "6, 7, 8, 9\n" + "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, " + "2, 3, 4, 5\n" + "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, " + "6, 7, 8, 9\n" + "vld1.32 {d10}, [%[r2]] @ load r2, 2 " + "float\n" + + /* r2 * w6, w7, w8, get out r0*/ + /* r2 * w3, w4, w5, get out r1*/ + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " + "0, 1, 2, 3\n" + "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, " + "4, 5, 6, 7\n" + "vext.32 q8, q3, q4, #1 @ r2, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r2, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, " + "0, 1, 2, 3\n" + "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, " + "4, 5, 6, 7\n" + "vext.32 q10, q3, q4, #2 @ r2, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r2, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, " + "1, 2, 3, 4\n" + "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, " + "5, 6, 7, 8\n" + "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, " + "1, 2, 3, 4\n" + "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, " + "5, 6, 7, 8\n" + "vld1.32 {d6-d9}, [%[r3]]! @ load r3, 8 " + "float\n" + "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, " + "2, 3, 4, 5\n" + "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, " + "6, 7, 8, 9\n" + "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, " + "2, 3, 4, 5\n" + "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, " + "6, 7, 8, 9\n" + "vld1.32 {d10}, [%[r3]] @ load r3, 2 " + "float\n" + + /* r3 * w6, w7, w8, get out r1*/ + "vext.32 q8, q3, q4, #1 @ r3, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r3, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, " + "0, 1, 2, 3\n" + "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, " + "4, 5, 6, 7\n" + "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, " + "or01\n" + "vext.32 q10, q3, q4, #2 @ r3, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r3, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, " + "0, 1, 2, 3\n" + "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, " + "4, 5, 6, 7\n" + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " + "or01\n" + "vld1.32 {d6-d9}, [%[r0]]! @ load r3, 8 " + "float\n" + "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, " + "2, 3, 4, 5\n" + "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, " + "6, 7, 8, 9\n" + "vld1.32 {d10}, [%[r0]] @ load r0, 2 " + "float\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, " + "or11\n" + + "subs %[cnt], #1 @loop count " + "-1\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), + [r2] "+r"(r2), [r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) + : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "q14", "q15"); + r0 -= 8; + } + //! deal with remain wout + if (w_loop & 1) { + ptr_out0[0] += + r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + + r1[0] * w_tmp[3] + r1[1] * w_tmp[4] + r1[2] * w_tmp[5] + + r2[0] * w_tmp[6] + r2[1] * w_tmp[7] + r2[2] * w_tmp[8]; + + ptr_out0[1] += + r0[1] * w_tmp[0] + r0[2] * w_tmp[1] + r0[3] * w_tmp[2] + + r1[1] * w_tmp[3] + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + + r2[1] * w_tmp[6] + r2[2] * w_tmp[7] + r2[3] * w_tmp[8]; + + ptr_out0[2] += + r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + + r1[2] * w_tmp[3] + r1[3] * w_tmp[4] + r1[4] * w_tmp[5] + + r2[2] * w_tmp[6] + r2[3] * w_tmp[7] + r2[4] * w_tmp[8]; + + ptr_out0[3] += + r0[3] * w_tmp[0] + r0[4] * w_tmp[1] + r0[5] * w_tmp[2] + + r1[3] * w_tmp[3] + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + + r2[3] * w_tmp[6] + r2[4] * w_tmp[7] + r2[5] * w_tmp[8]; + + ptr_out1[0] += + r1[0] * w_tmp[0] + r1[1] * w_tmp[1] + r1[2] * w_tmp[2] + + r2[0] * w_tmp[3] + r2[1] * w_tmp[4] + r2[2] * w_tmp[5] + + r3[0] * w_tmp[6] + r3[1] * w_tmp[7] + r3[2] * w_tmp[8]; + + ptr_out1[1] += + r1[1] * w_tmp[0] + r1[2] * w_tmp[1] + r1[3] * w_tmp[2] + + r2[1] * w_tmp[3] + r2[2] * w_tmp[4] + r2[3] * w_tmp[5] + + r3[1] * w_tmp[6] + r3[2] * w_tmp[7] + r3[3] * w_tmp[8]; + + ptr_out1[2] += + r1[2] * w_tmp[0] + r1[3] * w_tmp[1] + r1[4] * w_tmp[2] + + r2[2] * w_tmp[3] + r2[3] * w_tmp[4] + r2[4] * w_tmp[5] + + r3[2] * w_tmp[6] + r3[3] * w_tmp[7] + r3[4] * w_tmp[8]; + + ptr_out1[3] += + r1[3] * w_tmp[0] + r1[4] * w_tmp[1] + r1[5] * w_tmp[2] + + r2[3] * w_tmp[3] + r2[4] * w_tmp[4] + r2[5] * w_tmp[5] + + r3[3] * w_tmp[6] + r3[4] * w_tmp[7] + r3[5] * w_tmp[8]; + } + + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + } + slidingwindow_writeout_c1_fp32(pre_out, dout_batch, c_idx, c_idx + 1, h, + h + h_kernel, 0, wout_round, chout, hout, + wout, relu, ptr_write); + } + } + } +} + +template <> +void SlidingwindowConv3x3s2Faster( + const framework::Tensor *input, framework::Tensor *filter, + const std::vector &paddings, framework::Tensor *output) { + const float *din = input->data(); + float *dout = output->mutable_data(); + const float *weights = filter->mutable_data(); + const float *bias = nullptr; + bool relu = false; + const int num = input->dims()[0]; + const int chin = input->dims()[1]; + const int hin = input->dims()[2]; + const int win = input->dims()[3]; + const int chout = output->dims()[1]; + const int hout = output->dims()[2]; + const int wout = output->dims()[3]; + const int pad_h = paddings[0]; + const int pad_w = paddings[1]; + const int threads = framework::CPUContext::Context()->get_thread_num(); + int l2_size = + framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float); + const int hout_c_block = 4; + const int hout_r_kernel = 2; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round * 2 /*stride_w*/ + 1; + //! get h block + //! win_round * chin * hin_r_block + wout_round * hout_c_block * hout_r_block + //! * threads = l2_size win_round = 2 * wout_round + 1 hin_r_block = 2 * + //! hout_r_block + 1 + int hout_r_block = + (l2_size - 2 * wout_round * chin - chin) / + ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1; + + float ptr_zero[win_round]; + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; + + int in_len = win_round * chin; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + float *pre_din = + static_cast(framework::CPUContext::Context()->get_work_space( + (pre_in_size + threads * pre_out_size) * sizeof(float))); + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = chin * 9; /*kernel_w * kernel_h*/ + int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int c_remain = chout - (chout / hout_c_block) * hout_c_block; + int c_round_down = (chout / hout_c_block) * hout_c_block; + + int out_row_stride = hout_c_block * wout_round; + + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * chin * size_in_channel; + float *dout_batch = dout + n * chout * size_out_channel; + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + + int hs = h * 2 /*stride_h*/ - pad_h; + int he = hs + h_kernel * 2 /*stride_h*/ + 1; + + slidingwindow_prepack_input(din_batch, pre_din, 0, chin, hs, he, ws, we, + chin, win, hin, ptr_zero); + + const float *cblock_inr0 = pre_din; + const float *cblock_inr1 = cblock_inr0 + in_len; + const float *cblock_inr2 = cblock_inr1 + in_len; + const float *cblock_inr3 = cblock_inr2 + in_len; + const float *cblock_inr4 = cblock_inr3 + in_len; + +#pragma omp parallel for + for (int c = 0; c < c_round_down; c += hout_c_block) { +#ifdef _OPENMP + float *pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float *pre_out = pre_din + pre_in_size; +#endif + const float *block_inr0 = cblock_inr0; + const float *block_inr1 = cblock_inr1; + const float *block_inr2 = cblock_inr2; + const float *block_inr3 = cblock_inr3; + const float *block_inr4 = cblock_inr4; + + const float *weight_c = weights + c * w_stride; + const float *bias_ptr = ptr_zero; + if (bias != nullptr) { + bias_ptr = bias + c; + } + slidingwindow_fill_bias(pre_out, bias_ptr, + wout_round * hout_c_block * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float *wc0 = weight_c; + + const float *inr0 = block_inr0; + const float *inr1 = block_inr1; + const float *inr2 = block_inr2; + const float *inr3 = block_inr3; + const float *inr4 = block_inr4; + + float *pre_out0 = pre_out + hk * out_row_stride; + float *pre_out1 = pre_out0 + out_row_stride; +#ifdef __aarch64__ + for (int i = 0; i < chin; ++i) { + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + float32x4_t w0 = vld1q_f32(wc0); // w0, v23 + float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + const float *r4 = inr4; + + int cnt = w_loop; + asm volatile( + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + + "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "2: \n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ + "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ + "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ + "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ + "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ + "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ + "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ + + "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ + + /* r2 mul w6, get out r0*/ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ + "fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ + "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ + + "ldr d11, [%[r1]] \n" /* load input r1, 9th + element*/ + + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ + "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ + "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ + "fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ + "fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/ + "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ + "fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ + + "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ + + /* r2 mul w7, get out r0 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ + "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ + "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ + + "ldr d13, [%[r3]] \n" /* load input r3, 9th + element*/ + + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ + "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ + "fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 * + r0[8]*/ + "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ + "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ + "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ + "fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 * + r2[8]*/ + + "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ + + /* r2, mul w8, get out r0 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ + "fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ + "fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 * + r2[8]*/ + + "ldr d14, [%[r4]] \n" /* load input r4, 9th + element*/ + + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ + "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ + "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ + "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ + "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ + "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ + "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ + + "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ + + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ + "fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/ + "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ + "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ + "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ + "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ + "fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ + + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ + "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ + "fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 * + r1[8]*/ + + "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + + "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ + "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ + "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ + "fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 * + r3[8]*/ + + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + + /* r4, mul w6, get out r1 */ + "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/ + "fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/ + "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/ + "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/ + + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + + /* r4, mul w7, get out r1 */ + "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/ + "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/ + "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/ + "fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/ + + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + + /* r4, mul w8, get out r1 */ + "fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/ + "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/ + "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/ + "fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 * + r4[8]*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [r3] "+r"(r3), [r4] "+r"(r4), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3), + [w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7), + [w8] "w"(w8) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < chin; ++i) { + const float *wc0 = weight_c + i * w_stride_chin; + + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + const float *r4 = inr4; + + int cnt = w_loop; + asm volatile( + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " + "outr0, w0, w1, c0~c3\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + /* load r0, r2 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " + "8 float\n" + "vld1.32 {d8}, [%[r0]] @ load r0, " + "9th float\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + /* main loop */ + "0: @ main " + "loop\n" + /* mul r0, with w0, w1, w2 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " + "outr1, w0, w1, c0~c3\n" + "vmla.f32 q8, q5, d0[0] @ w0 * " + "inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " + "outr1, w2, w3, c0~c3\n" + "vmla.f32 q9, q5, d1[0] @ w0 * " + "inr02\n" + "vmla.f32 q10, q5, d2[0] @ w0 * " + "inr04\n" + "vmla.f32 q11, q5, d3[0] @ w0 * " + "inr06\n" + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, " + "8 float\n" + "vmla.f32 q8, q6, d0[1] @ w1 * " + "inr01\n" + "vmla.f32 q9, q6, d1[1] @ w1 * " + "inr03\n" + "vmla.f32 q10, q6, d2[1] @ w1 * " + "inr05\n" + "vmla.f32 q11, q6, d3[1] @ w1 * " + "inr07\n" + "vld1.32 {d9}, [%[r2]] @ load r2, " + "9th float\n" + "vmla.f32 q8, q7, d1[0] @ w2 * " + "inr02\n" + "vmla.f32 q9, q7, d2[0] @ w2 * " + "inr04\n" + "vmla.f32 q10, q7, d3[0] @ w2 * " + "inr06\n" + "vmla.f32 q11, q7, d8[0] @ w2 * " + "inr08\n" + + "sub %[r2], %[r2], #32 @ r2 - 32, " + "load r2 twice\n" + + /* mul r2, with w0, w1, w2 */ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, " + "8 float\n" + "vmla.f32 q12, q5, d4[0] @ w0 * " + "inr20\n" + "vmla.f32 q13, q5, d5[0] @ w0 * " + "inr22\n" + "vmla.f32 q14, q5, d6[0] @ w0 * " + "inr24\n" + "vmla.f32 q15, q5, d7[0] @ w0 * " + "inr26\n" + "vld1.32 {d8}, [%[r1]] @ load r1, " + "9th float\n" + "vmla.f32 q12, q6, d4[1] @ w1 * " + "inr21\n" + "vmla.f32 q13, q6, d5[1] @ w1 * " + "inr23\n" + "vmla.f32 q14, q6, d6[1] @ w1 * " + "inr25\n" + "vmla.f32 q15, q6, d7[1] @ w1 * " + "inr27\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " + "w4, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w2 * " + "inr22\n" + "vmla.f32 q13, q7, d6[0] @ w2 * " + "inr24\n" + "vmla.f32 q14, q7, d7[0] @ w2 * " + "inr26\n" + "vmla.f32 q15, q7, d9[0] @ w2 * " + "inr28\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " + "to q7\n" + + /* mul r1, with w3, w4, w5 */ + "vmla.f32 q8, q5, d0[0] @ w3 * " + "inr10\n" + "vmla.f32 q9, q5, d1[0] @ w3 * " + "inr12\n" + "vmla.f32 q10, q5, d2[0] @ w3 * " + "inr14\n" + "vmla.f32 q11, q5, d3[0] @ w3 * " + "inr16\n" + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, " + "8 float\n" + "vmla.f32 q8, q6, d0[1] @ w4 * " + "inr11\n" + "vmla.f32 q9, q6, d1[1] @ w4 * " + "inr13\n" + "vmla.f32 q10, q6, d2[1] @ w4 * " + "inr15\n" + "vmla.f32 q11, q6, d3[1] @ w4 * " + "inr17\n" + "vld1.32 {d9}, [%[r3]] @ load r3, " + "9th float\n" + "vmla.f32 q8, q7, d1[0] @ w5 * " + "inr12\n" + "vmla.f32 q9, q7, d2[0] @ w5 * " + "inr14\n" + "vmla.f32 q10, q7, d3[0] @ w5 * " + "inr16\n" + "vmla.f32 q11, q7, d8[0] @ w5 * " + "inr18\n" + + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " + "- 32, to start address\n" + + /* mul r3, with w3, w4, w5 */ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, " + "8 float\n" + "vmla.f32 q12, q5, d4[0] @ w3 * " + "inr30\n" + "vmla.f32 q13, q5, d5[0] @ w3 * " + "inr32\n" + "vmla.f32 q14, q5, d6[0] @ w3 * " + "inr34\n" + "vmla.f32 q15, q5, d7[0] @ w3 * " + "inr36\n" + "vld1.32 {d8}, [%[r2]] @ load r2, " + "9th float\n" + "vmla.f32 q12, q6, d4[1] @ w4 * " + "inr31\n" + "vmla.f32 q13, q6, d5[1] @ w4 * " + "inr33\n" + "vmla.f32 q14, q6, d6[1] @ w4 * " + "inr35\n" + "vmla.f32 q15, q6, d7[1] @ w4 * " + "inr37\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " + "w7, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w5 * " + "inr32\n" + "vmla.f32 q13, q7, d6[0] @ w5 * " + "inr34\n" + "vmla.f32 q14, q7, d7[0] @ w5 * " + "inr36\n" + "vmla.f32 q15, q7, d9[0] @ w5 * " + "inr38\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " + "to q7\n" + + /* mul r2, with w6, w7, w8 */ + "vmla.f32 q8, q5, d0[0] @ w6 * " + "inr20\n" + "vmla.f32 q9, q5, d1[0] @ w6 * " + "inr22\n" + "vmla.f32 q10, q5, d2[0] @ w6 * " + "inr24\n" + "vmla.f32 q11, q5, d3[0] @ w6 * " + "inr26\n" + "vld1.32 {d4-d7}, [%[r4]]! @ load r4, " + "8 float\n" + "vmla.f32 q8, q6, d0[1] @ w7 * " + "inr21\n" + "vmla.f32 q9, q6, d1[1] @ w7 * " + "inr23\n" + "vmla.f32 q10, q6, d2[1] @ w7 * " + "inr25\n" + "vmla.f32 q11, q6, d3[1] @ w7 * " + "inr27\n" + "vld1.32 {d9}, [%[r4]] @ load r4, " + "9th float\n" + "vmla.f32 q8, q7, d1[0] @ w8 * " + "inr22\n" + "vmla.f32 q9, q7, d2[0] @ w8 * " + "inr24\n" + "vmla.f32 q10, q7, d3[0] @ w8 * " + "inr26\n" + "vmla.f32 q11, q7, d8[0] @ w8 * " + "inr28\n" + + "sub %[wc0], %[wc0], #144 @ wc0 - " + "144 to start address\n" + + /* mul r4, with w6, w7, w8 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " + "8 float\n" + "vmla.f32 q12, q5, d4[0] @ w3 * " + "inr40\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " + "r00, r01, c0~c3\n" + "vmla.f32 q13, q5, d5[0] @ w3 * " + "inr42\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save " + "r02, r03, c0~c3\n" + "vmla.f32 q14, q5, d6[0] @ w3 * " + "inr44\n" + "vmla.f32 q15, q5, d7[0] @ w3 * " + "inr46\n" + "vld1.32 {d8}, [%[r0]] @ load r0, " + "9th float\n" + "vmla.f32 q12, q6, d4[1] @ w4 * " + "inr41\n" + "vmla.f32 q13, q6, d5[1] @ w4 * " + "inr43\n" + "vmla.f32 q14, q6, d6[1] @ w4 * " + "inr45\n" + "vmla.f32 q15, q6, d7[1] @ w4 * " + "inr47\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w5 * " + "inr42\n" + "vmla.f32 q13, q7, d6[0] @ w5 * " + "inr44\n" + "vmla.f32 q14, q7, d7[0] @ w5 * " + "inr46\n" + "vmla.f32 q15, q7, d9[0] @ w5 * " + "inr48\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " + "r10, r11, c0~c3\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " + "r12, r13, c0~c3\n" + + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " + "outr0, w0, w1, c0~c3\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + "subs %[cnt], #1 @ loop " + "count--\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [r3] "+r"(r3), [r4] "+r"(r4), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), [wc0] "+r"(wc0) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr4; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + block_inr4 = block_inr3 + in_len; + } + + slidingwindow_writeout_c4_fp32(pre_out, dout_batch, c, c + hout_c_block, + h, h + h_kernel, 0, wout_round, chout, + hout, wout, relu, ptr_write); + } + +#pragma omp parallel for + for (int c = 0; c < c_remain; ++c) { +#ifdef USE_OPENMP + float *pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float *pre_out = pre_din + pre_in_size; +#endif + + const float *block_inr0 = cblock_inr0; + const float *block_inr1 = cblock_inr1; + const float *block_inr2 = cblock_inr2; + const float *block_inr3 = cblock_inr3; + const float *block_inr4 = cblock_inr4; + + //! get weights ptr of remained + const float *weight_c = weights + c_round_down * w_stride; + + //! fill bias to one channel + const float *bias_ptr = ptr_zero; + if (bias != nullptr) { + bias_ptr = bias + c_round_down + c; + } + slidingwindow_fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float *wc0 = weight_c; + + const float *inr0 = block_inr0; + const float *inr1 = block_inr1; + const float *inr2 = block_inr2; + const float *inr3 = block_inr3; + const float *inr4 = block_inr4; + + float *pre_out0 = pre_out + hk * wout_round; + float *pre_out1 = pre_out0 + wout_round; +#ifdef __aarch64__ + for (int i = 0; i < chin; ++i) { + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + //! get valid weights of current output channel + float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 + float32x4_t w1 = vdupq_n_f32(wc0[c + 4]); // w1, v24 + float32x4_t w2 = vdupq_n_f32(wc0[c + 8]); // w2, v25 + float32x4_t w3 = vdupq_n_f32(wc0[c + 12]); // w3, v26 + float32x4_t w4 = vdupq_n_f32(wc0[c + 16]); // w4, v27 + float32x4_t w5 = vdupq_n_f32(wc0[c + 20]); // w5, v28 + float32x4_t w6 = vdupq_n_f32(wc0[c + 24]); // w6, v29 + float32x4_t w7 = vdupq_n_f32(wc0[c + 28]); // w7, v30 + float32x4_t w8 = vdupq_n_f32(wc0[c + 32]); // w8, v31 + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + const float *r4 = inr4; + + int cnt = w_loop; + asm volatile( + "ldr q21, [%[ptr_out0]] \n" /* load outr00, outr01, + outr02, outr03*/ + + "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "2: \n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11, + outr12, outr13*/ + + "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2, + 4, 6]*/ + "fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2, + 4, 6]*/ + + "ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/ + + /* r2 mul w6, get out r0*/ + "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2, + 4, 6]*/ + "ldr d11, [%[r1]] \n" /* load input r1, 9th + element*/ + + /* shift left 1 */ + "ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/ + "ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/ + + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3, + 5, 7]*/ + "fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3, + 5, 7]*/ + + "ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/ + + /* r2 mul w7, get out r0 */ + "fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1, + 3, 5, 7]*/ + + "ldr d13, [%[r3]] \n" /* load input r3, 9th + element*/ + + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * r0[2, 4, + 6, 8]*/ + "fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4, + 6, 8]*/ + + "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/ + + /* r2, mul w8, get out r0 */ + "fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2, + 4, 6, 8]*/ + + "ldr d14, [%[r4]] \n" /* load input r4, 9th + element*/ + + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2, + 4, 6]*/ + "fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2, + 4, 6]*/ + + /* shift left 1 */ + "ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/ + "ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/ + + "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ + + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3, + 5, 7]*/ + "fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3, + 5, 7]*/ + + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/ + "fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/ + + "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/ + + /* r4, mul w6, get out r1 */ + "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4[0, 2, + 4, 6]*/ + + "ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/ + "ldr q21, [%[ptr_out0]] \n" /* load outr0*/ + + /* r4, mul w7, get out r1 */ + "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3, + 5, 7]*/ + + /* r4, mul w8, get out r1 */ + "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4, + 6, 8]*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + "str q22, [%[ptr_out1]], #16 \n" /* save outr1*/ + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [r3] "+r"(r3), [r4] "+r"(r4), [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3), + [w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7), + [w8] "w"(w8) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v21", "v22"); + + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < chin; ++i) { + float *ptr_out0 = pre_out0; + float *ptr_out1 = pre_out1; + + //! get valid weights of current output channel + float w_tmp[12] = {wc0[c], wc0[c + 4], wc0[c + 8], 0.f, + wc0[c + 12], wc0[c + 16], wc0[c + 20], 0.f, + wc0[c + 24], wc0[c + 28], wc0[c + 32], 0.f}; + float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 + float32x4_t w1 = vld1q_f32(w_tmp + 4); // w3, w4, w5, q1 + float32x4_t w2 = vld1q_f32(w_tmp + 8); // w6, w7, w8, q2 + + const float *r0 = inr0; + const float *r1 = inr1; + const float *r2 = inr2; + const float *r3 = inr3; + const float *r4 = inr4; + + int cnt = w_loop / 2; + if (cnt > 0) { + asm volatile( + /* main loop */ + "0: @ " + "main loop\n" + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " + "or01\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " + "or11\n" + "vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 " + "float, interleave\n" + "vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 " + "float, interleave\n" + "vld1.32 {d22}, [%[r2]] @ load 16th " + "float\n" + + /* r2 * w2, r2 * w0, get or0, or1 */ + "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, " + "1, 3, 5, 7\n" + "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, " + "9, 11, 13, 15\n" + "vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 " + "float, interleave\n" + "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, " + "1, 3, 5, 7\n" + "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, " + "9, 11, 13, 15\n" + + "vext.32 q4, q3, q5, #1 @ r2, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q6, q5, q11, #1 @ r2, shift " + "left 1, get 10, 12, 14, 16\n" + + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " + "0, 2, 4, 6\n" + "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, " + "8, 10, 12, 14\n" + "vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 " + "float, interleave\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, " + "0, 2, 4, 6\n" + "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, " + "8, 10, 12, 14\n" + + "vld1.32 {d22}, [%[r0]] @ load 16th " + "float\n" + + "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, " + "2, 4, 6, 8\n" + "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, " + "2, 4, 6, 8\n" + "vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 " + "float, interleave\n" + "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, " + "10, 12, 14, 16\n" + "vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, " + "10, 12, 14, 16\n" + "vld2.32 {d10-d13}, [%[r3]]! @ load r3, 8 " + "float, interleave\n" + + /* r0 * w0, get or0, r3 * w1, get or1*/ + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " + "1, 3, 5, 7\n" + "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, " + "9, 11, 13, 15\n" + "vext.32 q8, q7, q9, #1 @ r0, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q10, q9, q11, #1 @ r0, shift " + "left 1, get 10, 12, 14, 16\n" + "vld1.32 {d22}, [%[r3]] @ load 16th " + "float\n" + "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, " + "1, 3, 5, 7\n" + "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, " + "9, 11, 13, 15\n" + + "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, " + "0, 2, 4, 6\n" + "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, " + "8, 10, 12, 14\n" + "vext.32 q4, q3, q5, #1 @ r3, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q6, q5, q11, #1 @ r3, shift " + "left 1, get 10, 12, 14, 16\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, " + "0, 2, 4, 6\n" + "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, " + "8, 10, 12, 14\n" + + "vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, " + "2, 4, 6, 8\n" + "vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 " + "float, interleave\n" + "vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, " + "10, 12, 14, 16\n" + "vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 " + "float, interleave\n" + "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, " + "2, 4, 6, 8\n" + "vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 " + "float, interleave\n" + "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, " + "10, 12, 14, 16\n" + "vld2.32 {d10-d13}, [%[r4]]! @ load r4, 8 " + "float, interleave\n" + + "vld1.32 {d22}, [%[r1]] @ load 16th " + "float\n" + + /* r1 * w1, get or0, r4 * w2, get or1 */ + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " + "1, 3, 5, 7\n" + "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, " + "9, 11, 13, 15\n" + "vext.32 q8, q7, q9, #1 @ r1, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q10, q9, q11, #1 @ r1, shift " + "left 1, get 10, 12, 14, 16\n" + "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, " + "1, 3, 5, 7\n" + "vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, " + "9, 11, 13, 15\n" + "vld1.32 {d22}, [%[r4]] @ load 16th " + "float\n" + + "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, " + "0, 2, 4, 6\n" + "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, " + "8, 10, 12, 14\n" + "vext.32 q4, q3, q5, #1 @ r1, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q6, q5, q11, #1 @ r1, shift " + "left 1, get 10, 12, 14, 16\n" + "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, " + "0, 2, 4, 6\n" + "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, " + "8, 10, 12, 14\n" + + "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, " + "2, 4, 6, 8\n" + "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, " + "10, 12, 14, 16\n" + "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, " + "2, 4, 6, 8\n" + "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, " + "10, 12, 14, 16\n" + + "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n" + + "subs %[cnt], #1 @loop count " + "-1\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), + [r2] "+r"(r2), [r3] "+r"(r3), [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) + : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "q14", "q15"); + } + //! deal with remain wout + if (w_loop & 1) { + ptr_out0[0] += + r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + + r1[0] * w_tmp[4] + r1[1] * w_tmp[5] + r1[2] * w_tmp[6] + + r2[0] * w_tmp[8] + r2[1] * w_tmp[9] + r2[2] * w_tmp[10]; + + ptr_out0[1] += + r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + r1[4] * w_tmp[6] + + r2[2] * w_tmp[8] + r2[3] * w_tmp[9] + r2[4] * w_tmp[10]; + + ptr_out0[2] += + r0[4] * w_tmp[0] + r0[5] * w_tmp[1] + r0[6] * w_tmp[2] + + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + r1[6] * w_tmp[6] + + r2[4] * w_tmp[8] + r2[5] * w_tmp[9] + r2[6] * w_tmp[10]; + + ptr_out0[3] += + r0[6] * w_tmp[0] + r0[7] * w_tmp[1] + r0[8] * w_tmp[2] + + r1[6] * w_tmp[4] + r1[7] * w_tmp[5] + r1[8] * w_tmp[6] + + r2[6] * w_tmp[8] + r2[7] * w_tmp[9] + r2[8] * w_tmp[10]; + + ptr_out1[0] += + r2[0] * w_tmp[0] + r2[1] * w_tmp[1] + r2[2] * w_tmp[2] + + r3[0] * w_tmp[4] + r3[1] * w_tmp[5] + r3[2] * w_tmp[6] + + r4[0] * w_tmp[8] + r4[1] * w_tmp[9] + r4[2] * w_tmp[10]; + + ptr_out1[1] += + r2[2] * w_tmp[0] + r2[3] * w_tmp[1] + r2[4] * w_tmp[2] + + r3[2] * w_tmp[4] + r3[3] * w_tmp[5] + r3[4] * w_tmp[6] + + r4[2] * w_tmp[8] + r4[3] * w_tmp[9] + r4[4] * w_tmp[10]; + + ptr_out1[2] += + r2[4] * w_tmp[0] + r2[5] * w_tmp[1] + r2[6] * w_tmp[2] + + r3[4] * w_tmp[4] + r3[5] * w_tmp[5] + r3[6] * w_tmp[6] + + r4[4] * w_tmp[8] + r4[5] * w_tmp[9] + r4[6] * w_tmp[10]; + + ptr_out1[3] += + r2[6] * w_tmp[0] + r2[7] * w_tmp[1] + r2[8] * w_tmp[2] + + r3[6] * w_tmp[4] + r3[7] * w_tmp[5] + r3[8] * w_tmp[6] + + r4[6] * w_tmp[8] + r4[7] * w_tmp[9] + r4[8] * w_tmp[10]; + } + + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr4; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + block_inr4 = block_inr3 + in_len; + } + slidingwindow_writeout_c1_fp32( + pre_out, dout_batch, c + c_round_down, c + c_round_down + 1, h, + h + h_kernel, 0, wout_round, chout, hout, wout, relu, ptr_write); + } + } + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/slidingwindow_conv3x3.h b/src/operators/math/slidingwindow_conv3x3.h index cbcdbc170f1c01866fe402447e07b6ab189a535b..9ef8fd2b3fff4c449eea3b41013862dd76c5d3c0 100644 --- a/src/operators/math/slidingwindow_conv3x3.h +++ b/src/operators/math/slidingwindow_conv3x3.h @@ -33,6 +33,17 @@ void SlidingwindowConv3x3s2(const framework::Tensor *input, const std::vector &paddings, framework::Tensor *output); +template +void SlidingwindowConv3x3s1Faster(const framework::Tensor *input, + framework::Tensor *filter, + const std::vector &paddings, + framework::Tensor *output); + +template +void SlidingwindowConv3x3s2Faster(const framework::Tensor *input, + framework::Tensor *filter, + const std::vector &paddings, + framework::Tensor *output); } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/slidingwindow_utils.cpp b/src/operators/math/slidingwindow_utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd20612482703bd6f772a07a249a7a9f5c4fdb29 --- /dev/null +++ b/src/operators/math/slidingwindow_utils.cpp @@ -0,0 +1,365 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/math/slidingwindow_utils.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +void slidingwindow_fill_bias(float* dout, const float* bias, int ch_num, + int ch_size) { + for (int j = 0; j < ch_num; j++) { + float32x4_t vb = vdupq_n_f32(bias[j]); + int i = 0; + for (; i < ch_size - 3; i += 4) { + vst1q_f32(dout + i, vb); + } + for (; i < ch_size; i++) { + dout[i] = bias[j]; + } + dout += ch_size; + } +} + +/* write result in outputs + * input din: [n, c, h, w], output dout: [n, c, h, w] + */ +void slidingwindow_writeout_c1_fp32(const float* din, float* dout, int cs, + int ce, int hs, int he, int ws, int we, + int channel, int height, int width, + bool flag_relu, float* trash_ptr) { + if (cs > channel) { + return; + } + + const int c1 = 1; + const int w4 = 4; + + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int w_round = we - ws; + int cnt = (width - ws) / w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + const float* din_hei_ptr = ptr_din + i * w_round * c1; + if (cnt > 0) { + int cnt_loop = cnt; + if (flag_relu) { +#ifdef __aarch64__ + asm volatile( + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop */ + "fmax v1.4s, v0.4s, v20.4s \n" /* relu */ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + "str q1, [%[doutc0r0]], #16 \n" /* store c0r0 */ + "bne 1b \n" /* jump to main loop */ + : [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v20"); +#else + asm volatile( + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, c1r0, " + "c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + + "vmax.f32 q1, q0, q15 @ relu\n" + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" + + "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q15"); +#endif + } else { +#ifdef __aarch64__ + asm volatile( + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "1: \n" /* main loop */ + "str q0, [%[doutc0r0]], #16 \n" /* store c2r0 */ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "bne 1b \n" /* jump to main loop */ + + : [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0"); +#else + asm volatile( + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, c0r1, " + "c0r2, c0r3\n" + "1: @ main loop\n" + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0"); +#endif + } + } + if (we > width) { + int offset = i * w_round * c1 + c1 * w4 * cnt; + din_hei_ptr = ptr_din + offset; + int j = we - w4; + if (flag_relu) { + for (; j < width; ++j) { + *(doutc0_ptr++) = std::max(din_hei_ptr[0], 0.f); + din_hei_ptr++; + } + } else { + for (; j < width; ++j) { + *(doutc0_ptr++) = *(din_hei_ptr++); + } + } + } + } +} + +/* write result in outputs + * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] + */ +void slidingwindow_writeout_c4_fp32(const float* din, float* dout, int cs, + int ce, int hs, int he, int ws, int we, + int channel, int height, int width, + bool flag_relu, float* trash_ptr) { + const int c4 = 4; + const int w4 = 4; + const int w_round = we - ws; + const int ch_n = ce - cs; + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int cnt = (width - ws) / w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + const float* din_hei_ptr = ptr_din + i * w_round * ch_n; + if (cnt > 0) { + int cnt_loop = cnt; + if (flag_relu) { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop */ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1 */ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1 */ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3 */ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10 */ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10 */ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11 */ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11 */ + "fmax v16.4s, v16.4s, v20.4s \n" /* relu */ + "fmax v17.4s, v17.4s, v20.4s \n" /* relu */ + "fmax v18.4s, v18.4s, v20.4s \n" /* relu */ + "fmax v19.4s, v19.4s, v20.4s \n" /* relu */ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0 */ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0 */ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0 */ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0 */ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + "bne 1b \n" /* jump to main loop */ + + : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v16", "v17", "v18", "v19", "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" + "vmov.u32 q15, #0 @ dump zero \n" + "1: @ main loop \n" + "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " + "\n" + "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " + "\n" + + "vswp d1, d4 @ swap data\n" + "vswp d3, d6 @ swap data\n" + + "vmax.f32 q0, q0, q15 @ relu\n" + "vmax.f32 q1, q1, q15 @ relu\n" + "vmax.f32 q2, q2, q15 @ relu\n" + "vmax.f32 q3, q3, q15 @ relu\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } else { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "1: \n" /* main loop */ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1 */ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1 */ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3 */ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10 */ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10 */ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11 */ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11 */ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0 */ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0 */ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0 */ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0 */ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + "bne 1b \n" /* jump to main loop */ + + : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", + "v18", "v19"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" + "1: @ main loop \n" + "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " + "\n" + "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " + "\n" + + "vswp d1, d4 @ swap data\n" + "vswp d3, d6 @ swap data\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3"); +#endif + } + } + if (we > width) { + int offset = i * w_round * c4 + c4 * w4 * cnt; + din_hei_ptr = ptr_din + offset; + int j = we - w4; + if (flag_relu) { + for (; j < width; ++j) { + *(doutc0_ptr++) = std::max(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = std::max(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = std::max(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = std::max(din_hei_ptr[3], 0.f); + din_hei_ptr += w4; + } + } else { + for (; j < width; ++j) { + *(doutc0_ptr++) = din_hei_ptr[0]; + *(doutc1_ptr++) = din_hei_ptr[1]; + *(doutc2_ptr++) = din_hei_ptr[2]; + *(doutc3_ptr++) = din_hei_ptr[3]; + din_hei_ptr += w4; + } + } + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/slidingwindow_utils.h b/src/operators/math/slidingwindow_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6db22bcf5fef126a8a830a0c30da87331fea5e0a --- /dev/null +++ b/src/operators/math/slidingwindow_utils.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "framework/tensor.h" + +#if __ARM_NEON +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +/* preprocessing weights + * input weights: [chout, chin/ group, kh, kw] --> outputs weights: [chout / n, + * chin/ group, kh, kw, n] + */ +template +void slidingwindow_transform_weight(const framework::Tensor& weight, + framework::Tensor* output) { + int chout = weight.dims()[0]; + int chin = weight.dims()[1]; + int kernel_size = weight.dims()[2] * weight.dims()[3]; + const int n = 4; + int cround = (chout + n - 1) / n * n; + const dtype* din = weight.data(); + dtype* dout = output->mutable_data({cround, chin, 3, 3}); + int c_loop = chout / n; + int chout_round = (chout + n - 1) / n; + int win_stride = chin * kernel_size; + int wout_stride = n * win_stride; + int co = 0; + for (; co < c_loop; ++co) { + dtype* dout_c = dout + co * wout_stride; + const dtype* din_array[n]; + din_array[0] = din + co * wout_stride; + for (int i = 1; i < n; i++) { + din_array[i] = din_array[i - 1] + win_stride; + } + for (int ci = 0; ci < chin; ++ci) { + for (int k = 0; k < kernel_size; ++k) { + for (int i = 0; i < n; i++) { + *(dout_c++) = *(din_array[i]++); + } + } + } + } + // pad final chout + if (chout_round > c_loop) { + dtype* dout_c = dout + c_loop * wout_stride; + const dtype* din_array[n]; + din_array[0] = din + c_loop * wout_stride; + for (int i = 1; i < n; i++) { + din_array[i] = din_array[i - 1] + win_stride; + } + // deal remain + int cremain = chout_round * n - chout; + for (int i = 1; i <= cremain; i++) { + din_array[n - i] = din_array[0]; + } + for (int ci = 0; ci < chin; ++ci) { + for (int k = 0; k < kernel_size; ++k) { + for (int i = 0; i < n; i++) { + *(dout_c++) = *(din_array[i]++); + } + } + } + } +} + +/* preprocessing inputs + * input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws] + * n = he - hs + */ +template +void slidingwindow_prepack_input(const dtype* din, dtype* dout, int cs, int ce, + int hs, int he, int ws, int we, int channel, + int width, int height, dtype* zero_ptr) { + int n = he - hs; + int w0 = ws < 0 ? 0 : ws; + int w1 = we > width ? width : we; + + int size_w = we - ws; + int size_wc_len = size_w * channel; + int size_c = width * height; + + int valid_w = w1 - w0; + size_t valid_w_byte = valid_w * sizeof(dtype); + + dtype* out_array[n]; + out_array[0] = dout; + for (int i = 1; i < n; i++) { + out_array[i] = out_array[i - 1] + size_wc_len; + } + + for (int c = 0; c < channel; ++c) { + int j = 0; + // valid height + for (int i = hs; i < he; i++) { + // get address + const dtype* in_array; + if (i < 0 || i >= height) { + in_array = zero_ptr; + } else { + in_array = din + i * width; + } + + for (int w = ws; w < w0; ++w) { + *(out_array[j]++) = 0.f; + } + memcpy(out_array[j], in_array, valid_w_byte); + out_array[j] += valid_w; + for (int w = w1; w < we; ++w) { + *(out_array[j]++) = 0.f; + } + j++; + } + din += size_c; + } +} + +inline void slidingwindow_fill_bias(float* dout, const float* bias, int size) { + float32x4_t vb = vld1q_f32(bias); + int cnt = size / 4; + for (int i = 0; i < cnt; ++i) { + vst1q_f32(dout, vb); + dout += 4; + } +} + +void slidingwindow_fill_bias(float* dout, const float* bias, int ch_num, + int ch_size); + +void slidingwindow_writeout_c1_fp32(const float* din, float* dout, int cs, + int ce, int hs, int he, int ws, int we, + int channel, int height, int width, + bool flag_relu, float* trash_ptr); + +void slidingwindow_writeout_c4_fp32(const float* din, float* dout, int cs, + int ce, int hs, int he, int ws, int we, + int channel, int height, int width, + bool flag_relu, float* trash_ptr); +} // namespace math +} // namespace operators +} // namespace paddle_mobile