Unverified commit 4b253569 authored by tensor-tang, committed by GitHub

[Lite] enable fc kernel (#17674)

* add fc unit test

* refine eigen fc
add cpu info, arm context
init packed sgemm

* enable packed sgemm

* add arm math

* pass fc ut

* follow comments
Parent e170ea03
@@ -16,6 +16,8 @@ if(NOT ANDROID)
   return()
 endif()

+add_definitions(-DLITE_WITH_ANDROID)
+
 if(NOT DEFINED ANDROID_NDK)
   set(ANDROID_NDK $ENV{NDK_ROOT})
   if(NOT ANDROID_NDK)
......
@@ -118,6 +118,7 @@ endfunction()
 add_subdirectory(core)
 add_subdirectory(x86)
+add_subdirectory(arm)
 add_subdirectory(host)
 add_subdirectory(cuda)
 add_subdirectory(operators)
......
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/funcs.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void fill_bias_fc<float>(float *tensor, const float *bias, const int num,
const int channel) {
int cnt = channel >> 4;
int remain = channel & 15;
for (int j = 0; j < num; ++j) {
const float *ptr_bias = bias;
float *ptr_out = tensor + j * channel;
float32x4_t vout1;
float32x4_t vout2;
float32x4_t vout3;
float32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
float32x4_t vin1 = vld1q_f32(ptr_out);
float32x4_t vb1 = vld1q_f32(ptr_bias);
float32x4_t vin2 = vld1q_f32(ptr_out + 4);
float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
float32x4_t vin3 = vld1q_f32(ptr_out + 8);
float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
vout1 = vaddq_f32(vin1, vb1);
vout2 = vaddq_f32(vin2, vb2);
vout3 = vaddq_f32(vin3, vb3);
vout4 = vaddq_f32(vin4, vb4);
vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2);
vst1q_f32(ptr_out + 8, vout3);
vst1q_f32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.f32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
}
#endif
for (; remain > 0; remain--) {
*(ptr_out++) += *(ptr_bias++);
}
}
}
template <>
void fill_bias_fc<int>(int *tensor, const int *bias, const int num,
const int channel) {
int cnt = channel >> 4;
int remain = channel & 15;
for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias;
int *ptr_out = tensor + j * channel;
int32x4_t vout1;
int32x4_t vout2;
int32x4_t vout3;
int32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
int32x4_t vin1 = vld1q_s32(ptr_out);
int32x4_t vb1 = vld1q_s32(ptr_bias);
int32x4_t vin2 = vld1q_s32(ptr_out + 4);
int32x4_t vb2 = vld1q_s32(ptr_bias + 4);
int32x4_t vin3 = vld1q_s32(ptr_out + 8);
int32x4_t vb3 = vld1q_s32(ptr_bias + 8);
int32x4_t vin4 = vld1q_s32(ptr_out + 12);
int32x4_t vb4 = vld1q_s32(ptr_bias + 12);
vout1 = vaddq_s32(vin1, vb1);
vout2 = vaddq_s32(vin2, vb2);
vout3 = vaddq_s32(vin3, vb3);
vout4 = vaddq_s32(vin4, vb4);
vst1q_s32(ptr_out, vout1);
vst1q_s32(ptr_out + 4, vout2);
vst1q_s32(ptr_out + 8, vout3);
vst1q_s32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.s32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
}
#endif
for (; remain > 0; remain--) {
*(ptr_out++) += *(ptr_bias++);
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
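For reference, a minimal usage sketch of fill_bias_fc (not part of the commit; sizes and values are illustrative). It broadcasts a length-`channel` bias across every row of a row-major [num x channel] buffer, which is exactly what FcCompute::Run below does after the bias-less sgemm:

// Hypothetical example, assuming LITE_WITH_ARM; not part of this commit.
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"

void fill_bias_fc_example() {
  const int num = 2, channel = 5;              // illustrative shape
  std::vector<float> out(num * channel, 0.f);  // e.g. a bias-less sgemm result
  std::vector<float> bias = {1.f, 2.f, 3.f, 4.f, 5.f};
  paddle::lite::arm::math::fill_bias_fc(out.data(), bias.data(), num, channel);
  // every row of out is now {1, 2, 3, 4, 5}
}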
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <cmath>
#include "paddle/fluid/lite/arm/math/packed_sgemm.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void fill_bias_fc(T* tensor, const T* bias, const int num, const int channel);
template <typename T>
void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W;
if (b) {
Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>> B(b, w_w);
Out = Out.array().rowwise() + B.array();
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
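A minimal sketch of calling fc_compute_eigen (not from the commit; data is hypothetical). Per the Eigen maps above, x is row-major [x_h, x_w], w is [w_h, w_w] with w_h == x_w, b has w_w entries, and out is [x_h, w_w]:

// Hypothetical example; mirrors the call in fc_compute_test.cc below.
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"

void fc_eigen_example() {
  std::vector<float> x(2 * 3, 1.f);   // [2, 3] input
  std::vector<float> w(3 * 4, 0.5f);  // [3, 4] weight
  std::vector<float> b(4, 0.1f);      // per-output-channel bias
  std::vector<float> out(2 * 4);      // [2, 4] result
  paddle::lite::arm::math::fc_compute_eigen(x.data(), 2, 3,  // x, x_h, x_w
                                            w.data(), 3, 4,  // w, w_h, w_w
                                            b.data(), out.data());
  // out[r][c] == dot(x row r, w col c) + b[c] == 3 * 0.5 + 0.1 == 1.6
}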
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/cpu_info.h"
#include "paddle/fluid/lite/core/lite_tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
constexpr int MBLOCK = 8;
constexpr int NBLOCK = 12;
constexpr int KBLOCK = 4;
inline int get_hblock(ARMArch arch) { return MBLOCK; }
#else
constexpr int MBLOCK_A73 = 4;
constexpr int MBLOCK_OTH = 6;
constexpr int NBLOCK = 8;
constexpr int KBLOCK = 4;
inline int get_hblock(ARMArch arch) {
if (arch == kA73) {
return MBLOCK_A73;
} else {
return MBLOCK_OTH;
}
}
#endif // __aarch64__
void prepackA(float* out, const float* in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax, bool is_trans,
ARMContext* ctx);
void prepackA(TensorLite* tout, const TensorLite& tin, int m, int k, int group,
bool is_trans, ARMContext* ctx);
void sgemm_prepack(const float* A_packed, const float* B, const float* bias,
float* C, int M, int N, int K, bool is_bias, bool is_relu,
bool is_transB, ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
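A hedged sketch of the pack-then-multiply flow these declarations support, mirroring FcCompute::Run further down in this diff (sizing the packed_A scratch buffer is the caller's responsibility and is simplified here):

// Hypothetical example, assuming LITE_WITH_ARM and a live ARMContext.
void packed_sgemm_example(const float* A, const float* B, float* C,  // C = A * B
                          int M, int N, int K, float* packed_A,
                          paddle::lite::ARMContext* ctx) {
  using namespace paddle::lite::arm::math;
  // Repack rows [0, M) x cols [0, K) of A (leading dimension K) into the
  // blocked layout the sgemm micro-kernel expects.
  prepackA(packed_A, A, K, 0, M, 0, K, /*is_trans=*/false, ctx);
  // C[M x N] = packed_A * B; no fused bias or relu, B not transposed.
  sgemm_prepack(packed_A, B, /*bias=*/nullptr, C, M, N, K,
                /*is_bias=*/false, /*is_relu=*/false, /*is_transB=*/false, ctx);
}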
@@ -23,7 +23,8 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_l
 cc_library(variable_lite SRCS variable.cc)
 cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
 cc_library(scope_lite SRCS scope.cc)
-cc_library(context_lite SRCS context.cc DEPS any_lite)
+cc_library(cpu_info_lite SRCS cpu_info.cc)
+cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite)
 cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite ${tensor_lite})
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
......
@@ -12,8 +12,317 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-//
-// Created by chunwei on 19-2-22.
-//
 #include "paddle/fluid/lite/core/context.h"
+#include "paddle/fluid/lite/core/cpu_info.h"
#ifdef LITE_WITH_ANDROID
#include <sys/syscall.h>
#include <unistd.h>
#endif
#if __APPLE__
#include "TargetConditionals.h"
#if TARGET_OS_IPHONE
#include <mach/machine.h>
#include <sys/sysctl.h>
#include <sys/types.h>
#endif // TARGET_OS_IPHONE
#endif // __APPLE__
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
void ARMContext::SetCache(int l1size, int l2size, int l3size) {
DeviceInfo& dev = DeviceInfo::Global();
int cpu_count = arm_get_cpucount();
dev.L1_cache_.resize(cpu_count);
dev.L2_cache_.resize(cpu_count);
dev.L3_cache_.resize(cpu_count);
for (int i = 0; i < cpu_count; ++i) {
dev.L1_cache_[i] = l1size;
dev.L2_cache_[i] = l2size;
dev.L3_cache_[i] = l3size;
}
workspace_.Resize({2 * (l1size + l2size)});
}
ARMContext::ARMContext() {
active_ids_ = {0};
mode_ = LITE_POWER_HIGH;
DeviceInfo& dev = DeviceInfo::Global();
workspace_.Resize(
{static_cast<int64_t>(dev.L2_cache_[active_ids_[0]] / sizeof(float))});
#ifdef TARGET_IOS
arch_ = kAPPLE;  // use 6x8; ARMArch only declares kAPPLE
#else
if (dev.big_core_ids_.size() > 0) {
arch_ = dev.archs_[dev.big_core_ids_[0]];
}
#endif
}
PowerMode ARMContext::mode() const { return mode_; }
int ARMContext::threads() const { return active_ids_.size(); }
ARMContext::ARMContext(const ARMContext& ctx) {
mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_;
arch_ = ctx.arch_;
count_ = ctx.count_;
}
ARMContext& ARMContext::operator=(const ARMContext& ctx) {
mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_;
arch_ = ctx.arch_;
count_ = ctx.count_;
return *this;
}
void ARMContext::BindDev() {
#ifdef USE_OPENMP
int num_threads = active_ids_.size();
omp_set_num_threads(num_threads);
#ifdef LITE_WITH_ANDROID
std::vector<int> ssarets;
for (int j = 0; j < num_threads; ++j) {
ssarets.push_back(0);
}
#pragma omp parallel for
for (int i = 0; i < num_threads; i++) {
ssarets[i] = set_sched_affinity(active_ids_);
}
for (int i = 0; i < num_threads; i++) {
if (ssarets[i] != 0) {
LOGE("set cpu affinity failed, cpuID: %d\n", active_ids_[i]);
return;
}
}
#endif // LITE_WITH_ANDROID
#else // USE_OPENMP
#ifdef LITE_WITH_ANDROID
std::vector<int> cpuid1;
cpuid1.push_back(active_ids_[0]);
int ssaret = set_sched_affinity(cpuid1);
if (ssaret != 0) {
printf("set cpu affinity failed, cpuID: %d\n", active_ids_[0]);
return;
}
#endif // LITE_WITH_ANDROID
#endif // USE_OPENMP
}
void ARMContext::SetRunMode(PowerMode mode, int threads) {
DeviceInfo& dev = DeviceInfo::Global();
int big_core_size = dev.big_core_ids_.size();
int small_core_size = dev.little_core_ids_.size();
if (threads > big_core_size + small_core_size) {
threads = big_core_size + small_core_size;
}
#ifdef USE_OPENMP
count_++;
int shift_num = big_core_size > 0 ? (count_ / 10) % big_core_size : 0;
switch (mode) {
case LITE_POWER_FULL:
mode_ = mode;
active_ids_.clear();
for (int i = 0; i < threads; ++i) {
if (i < big_core_size) {
active_ids_.push_back(dev.big_core_ids_[i]);
} else {
active_ids_.push_back(dev.little_core_ids_[i - big_core_size]);
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_HIGH:
active_ids_.clear();
if (big_core_size > 0) {
mode_ = LITE_POWER_HIGH;
if (threads > big_core_size) {
LOGE("threads: %d, exceed the big cores size: %d\n", threads,
big_core_size);
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.big_core_ids_[i]);
}
}
} else {
mode_ = LITE_POWER_LOW;
LOGE("HIGH POWER MODE is not support, switch to little cores\n");
if (threads > small_core_size) {
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.little_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_LOW:
active_ids_.clear();
if (small_core_size > 0) {
mode_ = LITE_POWER_LOW;
if (threads > small_core_size) {
LOGW("threads: %d, exceed the little cores size: %d\n", threads,
small_core_size);
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.little_core_ids_[i]);
}
}
} else {
mode_ = LITE_POWER_HIGH;
LOGW("LOW POWER MODE is not support, switch to big cores\n");
if (threads > big_core_size) {
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.big_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_NO_BIND:
mode_ = LITE_POWER_NO_BIND;
active_ids_.clear();
if (threads > dev.core_ids_.size()) {
active_ids_.resize(dev.core_ids_.size());
} else {
active_ids_.resize(threads);
}
break;
case LITE_POWER_RAND_HIGH:
active_ids_.clear();
if (big_core_size > 0) {
mode_ = LITE_POWER_RAND_HIGH;
if (threads > big_core_size) {
LOGW("threads: %d, exceed the big cores size: %d\n", threads,
big_core_size);
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(
dev.big_core_ids_[(i + shift_num) % big_core_size]);
}
}
} else {
mode_ = LITE_POWER_LOW;
LOGW("HIGH POWER MODE is not support, switch to little cores\n");
if (threads > small_core_size) {
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.little_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_RAND_LOW:
active_ids_.clear();
if (small_core_size > 0) {
mode_ = LITE_POWER_RAND_LOW;
if (threads > small_core_size) {
LOGW("threads: %d, exceed the little cores size: %d\n", threads,
small_core_size);
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(
dev.little_core_ids_[(i + shift_num) % small_core_size]);
}
}
} else {
mode_ = LITE_POWER_HIGH;
LOGW("LOW POWER MODE is not support, switch to big cores\n");
if (threads > big_core_size) {
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.big_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
}
//! fix multi-threads LITE_POWER_HIGH mode
if (mode_ == LITE_POWER_NO_BIND || threads > 1) {
int threads = active_ids_.size();
omp_set_num_threads(threads);
} else {
if (check_online(active_ids_)) {
BindDev();
} else {
LOG(ERROR) << "core id " << active_ids_[0]
<< " is offline, switch to NO BIND MODE";
int threads = active_ids_.size();
omp_set_num_threads(threads);
}
}
#else
if (big_core_size > 0) {
active_ids_ = {dev.big_core_ids_[0]};
} else {
active_ids_ = {0};
}
#endif
//! alloc memory for sgemm in this context
int temp_mem_size =
DeviceInfo::Global().L2_cache_[active_ids_[0]] / sizeof(float);
workspace_.Resize({temp_mem_size});
arch_ = DeviceInfo::Global().archs_[active_ids_[0]];
}
ARMArch ARMContext::arch() const { return arch_; }
void ARMContext::SetArch(ARMArch arch) { arch_ = arch; }
int ARMContext::l1_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L1_cache_[active_ids_[0]];
}
int ARMContext::l2_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L2_cache_[active_ids_[0]];
}
int ARMContext::l3_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L3_cache_[active_ids_[0]];
}
bool ARMContext::ExtendWorkspace(DDimLite dims) {
auto count = dims.product();
auto old = workspace_.dims();
if (count == old.product()) {
return false;
}
workspace_.Resize(
{static_cast<int64_t>(count + l2_cache_size() / sizeof(float))});
return true;
}
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
@@ -26,6 +26,8 @@
 #include <memory>
 #include <set>
 #include <vector>
+#include "paddle/fluid/lite/core/cpu_info.h"
+#include "paddle/fluid/lite/core/lite_tensor.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"

 namespace paddle {
@@ -34,7 +36,44 @@ namespace lite {
 struct HostContext {};

 #ifdef LITE_WITH_ARM
-struct ARMContext {};
struct ARMContext {
public:
ARMContext();
ARMContext(PowerMode mode, int threads);
ARMContext(const ARMContext& ctx);
ARMContext& operator=(const ARMContext& ctx);
void SetRunMode(PowerMode mode, int threads);
void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch);
void BindDev();
PowerMode mode() const;
int threads() const;
ARMArch arch() const;
template <typename T>
T* workspace_data() {
return workspace_.mutable_data<T>();
}
int l1_cache_size() const;
int l2_cache_size() const;
int l3_cache_size() const;
bool ExtendWorkspace(DDimLite dims);
private:
// LITE_POWER_HIGH stands for using big cores,
// LITE_POWER_LOW stands for using little cores,
// LITE_POWER_FULL stands for using all cores
ARMArch arch_;
PowerMode mode_;
std::vector<int> active_ids_;
TensorLite workspace_;
int64_t count_{0};
};
#endif

#ifdef LITE_WITH_CUDA
......
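A minimal sketch of configuring the new ARMContext directly (assembled from the unit test and runtime pass later in this diff; the power mode and thread count are example values):

// Hypothetical example, assuming LITE_WITH_ARM is defined.
#include "paddle/fluid/lite/core/context.h"

void arm_context_example() {
  paddle::lite::DeviceInfo::init_info();  // probe core count, arch, caches once
  paddle::lite::ARMContext ctx;           // workspace sized from the L2 cache
  ctx.SetRunMode(paddle::lite::LITE_POWER_HIGH, /*threads=*/2);
  float* scratch = ctx.workspace_data<float>();  // per-context sgemm scratch
  (void)scratch;
}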
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/cpu_info.h"
#include <cstdarg>
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
void DeviceInfo::get_info(DeviceInfo* dev) {
set_default_cache(dev);
dev->compute_core_num_ = arm_get_cpucount();
dev->max_memory_ = arm_get_meminfo();
// get max freq
#ifdef LITE_WITH_ANDROID
std::vector<int> max_freq(dev->compute_core_num_);
for (int i = 0; i < dev->compute_core_num_; ++i) {
max_freq[i] = get_max_freq_khz(i) / 1000;
}
std::string cpu_name = arm_get_cpu_name();
if (get_cpu_info_from_name(dev, cpu_name) != true) {
arm_sort_cpuid_by_max_frequency(dev->compute_core_num_, &dev->core_ids_,
max_freq, &dev->cluster_ids_);
dev->big_core_ids_.clear();
dev->little_core_ids_.clear();
for (int i = 0; i < dev->cluster_ids_.size(); ++i) {
if (dev->cluster_ids_[i] == 0) {
dev->big_core_ids_.push_back(dev->core_ids_[i]);
} else {
dev->little_core_ids_.push_back(dev->core_ids_[i]);
}
}
arm_get_cpu_arch(&dev->archs_);
}
LOG(INFO) << "ARM multiprocessors number: " << dev->compute_core_num_;
for (int i = 0; i < dev->compute_core_num_; ++i) {
LOG(INFO) << "ARM multiprocessors ID: " << dev->core_ids_[i]
<< ", frequence: " << max_freq[i]
<< ", cluster ID: " << dev->cluster_ids_[dev->core_ids_[i]]
<< ", CPU ARCH: A" << dev->archs_[i];
}
LOG(INFO) << "L1 DataCache size is: ";
for (int i = 0; i < dev->compute_core_num_; ++i) {
LOG(INFO) << dev->L1_cache_[i] / 1024 << " KB";
}
LOG(INFO) << "L2 Cache size is: ";
for (int i = 0; i < dev->compute_core_num_; ++i) {
LOG(INFO) << dev->L2_cache_[i] / 1024 << " KB";
}
LOG(INFO) << "Total memory: " << dev->max_memory_ << "KB";
dev->max_freq_ = max_freq[0];
for (int j = 1; j < dev->compute_core_num_; ++j) {
if (dev->max_freq_ < max_freq[j]) {
dev->max_freq_ = max_freq[j];
}
}
#elif defined(TARGET_IOS)
arm_get_cpu_arch(&dev->archs_);
#endif
}
// cache_id : 0 -> L1, 1 -> L2, 2 -> L3
void set_cache_info(DeviceInfo* cpu_info, int cache_id, int argc, ...) {
va_list arg_ptr;
va_start(arg_ptr, argc);
std::vector<int>* cache;
switch (cache_id) {
case 0:
cache = &cpu_info->L1_cache_;
break;
case 1:
cache = &cpu_info->L2_cache_;
break;
case 2:
cache = &cpu_info->L3_cache_;
break;
default:
break;
}
int core_num = cpu_info->compute_core_num_;
cache->resize(core_num);
if (argc == 1) {
int cache_size = va_arg(arg_ptr, int);
for (int i = 0; i < core_num; ++i) {
(*cache)[i] = cache_size;
}
} else {
int big_core_num = cpu_info->big_core_ids_.size();
int little_core_num = cpu_info->little_core_ids_.size();
int big_core_cache_size = va_arg(arg_ptr, int);
int little_core_cache_size = va_arg(arg_ptr, int);
for (int i = 0; i < big_core_num; ++i) {
(*cache)[cpu_info->big_core_ids_[i]] = big_core_cache_size;
}
for (int i = 0; i < little_core_num; ++i) {
(*cache)[cpu_info->little_core_ids_[i]] = little_core_cache_size;
}
}
va_end(arg_ptr);
}
void set_arch_info(DeviceInfo* cpu_info, int argc, ...) {
va_list arg_ptr;
va_start(arg_ptr, argc);
int core_num = cpu_info->compute_core_num_;
cpu_info->archs_.resize(core_num);
if (argc == 1) {
ARMArch arch = (ARMArch)va_arg(arg_ptr, int);
for (int i = 0; i < core_num; ++i) {
cpu_info->archs_[i] = arch;
}
} else {
ARMArch big_core_arch = (ARMArch)va_arg(arg_ptr, int);
ARMArch little_core_arch = (ARMArch)va_arg(arg_ptr, int);
int big_core_num = cpu_info->big_core_ids_.size();
int little_core_num = cpu_info->little_core_ids_.size();
for (int i = 0; i < big_core_num; ++i) {
cpu_info->archs_[cpu_info->big_core_ids_[i]] = big_core_arch;
}
for (int i = 0; i < little_core_num; ++i) {
cpu_info->archs_[cpu_info->little_core_ids_[i]] = little_core_arch;
}
}
va_end(arg_ptr);
}
bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name) {
/* Snapdragon */
if (hardware_name.find("SDM845") != std::string::npos) { // 845
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA75, kA55);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 256 * 1024, 128 * 1024);
set_cache_info(cpu_info, 2, 1, 2048 * 1024);
return true;
} else if (hardware_name.find("SDM710") != std::string::npos) { // 710
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 0, 0};
set_arch_info(cpu_info, 2, kA75, kA55);
return true;
} else if (hardware_name.find("MSM8998") != std::string::npos) { // 835
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA73, kA53);
set_cache_info(cpu_info, 0, 2, 64 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024,
/* real L2 size is 2MB, but that gives bad performance
   on conv3x3s1 or gemm; use 1MB or 512KB instead */
1024 * 1024);
return true;
} else if (hardware_name.find("MSM8996") != std::string::npos) { // 820
cpu_info->compute_core_num_ = 4;
cpu_info->core_ids_ = {0, 1, 2, 3};
cpu_info->big_core_ids_ = {2, 3};
cpu_info->little_core_ids_ = {0, 1};
cpu_info->cluster_ids_ = {1, 1, 0, 0};
set_arch_info(cpu_info, 1, kA72);
set_cache_info(cpu_info, 0, 1, 24 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024);
return true;
} else if (hardware_name.find("SDM660") != std::string::npos ||
hardware_name.find("SDM636") != std::string::npos) { // 660, 636
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA73);
set_cache_info(cpu_info, 0, 2, 64 * 1024, 32 * 1024);
set_cache_info(cpu_info, 1, 1, 1024 * 1024);
return true;
} else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA72, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024);
return true;
} else if (hardware_name.find("MSM8953") != std::string::npos) { // 625
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->little_core_ids_ = {};
cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 1, 1024 * 1024);
return true;
} else if (hardware_name.find("MSM8939") != std::string::npos) { // 615
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {0, 1, 2, 3};
cpu_info->little_core_ids_ = {4, 5, 6, 7};
cpu_info->cluster_ids_ = {0, 0, 0, 0, 1, 1, 1, 1};
set_arch_info(cpu_info, 1, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 512 * 1024, 256 * 1024);
return true;
/* MediaTek */
} else if (hardware_name.find("MT6797") !=
std::string::npos) { // X20/X23/X25/X27
cpu_info->compute_core_num_ = 10;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
cpu_info->big_core_ids_ = {8, 9};
cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0};
set_arch_info(cpu_info, 2, kA72, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024);
return true;
} else if (hardware_name.find("MT6799") != std::string::npos) { // X30
cpu_info->compute_core_num_ = 10;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
cpu_info->big_core_ids_ = {8, 9};
cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0};
set_arch_info(cpu_info, 2, kA73, kA53);
return true;
} else if (hardware_name.find("MT6795") != std::string::npos ||
hardware_name.find("MT6762") != std::string::npos ||
hardware_name.find("MT6755T") != std::string::npos ||
hardware_name.find("MT6755S") != std::string::npos ||
hardware_name.find("MT6753") != std::string::npos ||
hardware_name.find("MT6752") != std::string::npos ||
hardware_name.find("MT6750") != std::string::npos) {
// X10, P22, P15/P18, MT6753, MT6752/MT6752M, MT6750
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->little_core_ids_ = {};
cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
return true;
} else if (hardware_name.find("MT6758") != std::string::npos ||
hardware_name.find("MT6757") != std::string::npos ||
hardware_name.find("MT6763") != std::string::npos ||
hardware_name.find("MT6755M") != std::string::npos ||
hardware_name.find("MT6755") !=
std::string::npos) { // P30, P20/P25, P23, P10
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
return true;
} else if (hardware_name.find("MT6771") != std::string::npos) { // P60
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA73, kA53);
return true;
} else if (hardware_name.find("MT6765") != std::string::npos ||
hardware_name.find("MT6739") != std::string::npos ||
hardware_name.find("MT6738") != std::string::npos ||
hardware_name.find("MT6737") !=
std::string::npos) { // A22, MT6739, MT6738, MT6767
cpu_info->compute_core_num_ = 4;
cpu_info->core_ids_ = {0, 1, 2, 3};
cpu_info->big_core_ids_ = {0, 1, 2, 3};
cpu_info->little_core_ids_ = {};
cpu_info->cluster_ids_ = {0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
return true;
}
return false;
}
size_t arm_get_meminfo() {
#ifdef LITE_WITH_ANDROID
// get cpu count from /proc/cpuinfo
FILE* fp = fopen("/proc/meminfo", "rb");
if (!fp) {
return 1;
}
int memsize = 0;  // scanned with %d below, returned as size_t
char line[1024];
while (!feof(fp)) {
char* s = fgets(line, 1024, fp);
if (!s) {
break;
}
sscanf(s, "MemTotal: %d kB", &memsize);
}
fclose(fp);
return memsize;
#elif defined(TARGET_IOS)
// to be implemented
printf("not implemented\n");
return 0;
#endif
}
int arm_get_cpucount() {
#ifdef LITE_WITH_ANDROID
// get cpu count from /sys/devices/system/cpu/cpunum/uevent
int max_cpu_count = 20;
int count = 0;
for (int i = 0; i < max_cpu_count; ++i) {
char path[256];
snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i);
FILE* fp = fopen(path, "rb");
if (!fp) {
break;
}
count++;
fclose(fp);
}
if (count < 1) {
count = 1;
}
return count;
#elif defined(TARGET_IOS)
int count = 0;
size_t len = sizeof(count);
sysctlbyname("hw.ncpu", &count, &len, NULL, 0);
if (count < 1) {
count = 1;
}
return count;
#else
return 1;
#endif
}
void arm_get_cpu_arch(std::vector<ARMArch>* archs) {
#ifdef LITE_WITH_ANDROID
archs->clear();
//! get CPU ARCH
FILE* fp = fopen("/proc/cpuinfo", "rb");
if (!fp) {
return;
}
char line[1024];
while (!feof(fp)) {
char* s = fgets(line, 1024, fp);
if (!s) {
break;
}
if (strstr(line, "part") != NULL) {
int arch_id = 0;
sscanf(s, "CPU part\t: %x", &arch_id);
switch (arch_id) {
case 0xd03:
archs->push_back(kA53);
break;
case 0xd05:
archs->push_back(kA55);
break;
case 0xd07:
archs->push_back(kA57);
break;
case 0xd08:
archs->push_back(kA72);
break;
case 0xd09:
archs->push_back(kA73);
break;
case 0xd0a:
archs->push_back(kA75);
break;
case 0x800:
// 835
archs->push_back(kA73);
break;
case 0x205:
// 820
archs->push_back(kA72);
break;
default:
LOG(ERROR) << "unknow type";
archs->push_back(kARMArch_UNKOWN);
}
}
}
fclose(fp);
int cpu_count = arm_get_cpucount();
if (archs->size() < cpu_count) {
for (int i = archs->size(); i < cpu_count; ++i) {
archs->push_back(archs->at(i - 1));
}
}
#endif
#ifdef TARGET_IOS
int cpu_count = arm_get_cpucount();
for (int i = 0; i < cpu_count; ++i) {
archs->push_back(kAPPLE);
}
#endif
}
#ifdef LITE_WITH_ANDROID
void set_default_cache(DeviceInfo* dev) {
int cpu_count = arm_get_cpucount();
dev->L1_cache_.resize(cpu_count);
dev->L2_cache_.resize(cpu_count);
dev->L3_cache_.resize(cpu_count);
#ifdef TARGET_IOS
for (int i = 0; i < cpu_count; ++i) {
dev->L1_cache_[i] = 64 * 1024;
dev->L2_cache_[i] = 2048 * 1024;
dev->L3_cache_[i] = 0;
}
#else
for (int i = 0; i < cpu_count; ++i) {
dev->L1_cache_[i] = 32 * 1024;
dev->L2_cache_[i] = 512 * 1024;
dev->L3_cache_[i] = 0;
}
#endif
}
std::string arm_get_cpu_name() {
FILE* fp = fopen("/proc/cpuinfo", "rb");
if (!fp) {
return "";
}
char line[1024];
while (!feof(fp)) {
char* s = fgets(line, 1024, fp);
if (!s) {
break;
}
if (strstr(line, "Hardware") != NULL) {
fclose(fp);
return std::string(line);
}
}
fclose(fp);
return "";
}
int get_max_freq_khz(int cpuid) {
// first try, for all possible cpu
char path[256];
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpuid);
FILE* fp = fopen(path, "rb");
if (!fp) {
// second try, for online cpu
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state",
cpuid);
fp = fopen(path, "rb");
if (!fp) {
// third try, for online cpu
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpuid);
fp = fopen(path, "rb");
if (!fp) {
return -1;
}
int max_freq_khz = -1;
fscanf(fp, "%d", &max_freq_khz);
fclose(fp);
return max_freq_khz;
}
}
int max_freq_khz = 0;
while (!feof(fp)) {
int freq_khz = 0;
int nscan = fscanf(fp, "%d %*d", &freq_khz);
if (nscan != 1) {
break;
}
if (freq_khz > max_freq_khz) {
max_freq_khz = freq_khz;
}
}
fclose(fp);
return max_freq_khz;
}
int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector<int>* cpuids,
const std::vector<int>& cpu_freq,
std::vector<int>* cluster_ids) {
if (cpu_count == 0) {
return 0;
}
cpuids->resize(cpu_count);
cluster_ids->resize(cpu_count);
for (int i = 0; i < cpu_count; i++) {
cpuids->at(i) = i;
}
// sort cpuid as big core first
// simple bubble sort
for (int i = 0; i < cpu_count; i++) {
for (int j = i + 1; j < cpu_count; j++) {
if (cpu_freq[i] < cpu_freq[j]) {
// swap
int tmp = cpuids->at(i);
cpuids->at(i) = cpuids->at(j);
cpuids->at(j) = tmp;
}
}
}
// SMP
int mid_max_freq_khz =
(cpu_freq[cpuids->at(0)] + cpu_freq[cpuids->at(cpu_count - 1)]) / 2;
for (int i = 0; i < cpu_count; i++) {
cpuids->at(i) = i;
if (cpu_freq[i] >= mid_max_freq_khz) {
cluster_ids->at(i) = 0;
} else {
cluster_ids->at(i) = 1;
}
}
return 0;
}
int check_online(const std::vector<int>& core_ids) {
if (core_ids.size() == 0) {
return 0;
}
char path[256];
int online = 1;
for (int i = 0; i < core_ids.size(); ++i) {
snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online",
core_ids[i]);
FILE* fp = fopen(path, "rb");
if (!fp) {
return 0;
}
int cur_online = 0;
fscanf(fp, "%d", &cur_online);
online &= cur_online;
fclose(fp);
}
return online;
}
int set_sched_affinity(const std::vector<int>& cpuids) {
// #define CPU_SETSIZE 1024
// #define __NCPUBITS (8 * sizeof (unsigned long))
// typedef struct
// {
// unsigned long __bits[CPU_SETSIZE / __NCPUBITS];
// } cpu_set_t;
// set affinity for thread
#ifdef __GLIBC__
pid_t pid = syscall(SYS_gettid);
#else
pid_t pid = gettid();
#endif
cpu_set_t mask;
CPU_ZERO(&mask);
for (int i = 0; i < cpuids.size(); i++) {
CPU_SET(cpuids[i], &mask);
}
int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask);
if (syscallret) {
LOG(ERROR) << "syscall error " << syscallret;
return -1;
}
return 0;
}
#endif // LITE_WITH_ANDROID
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/utils/cp_logging.h"
#ifdef LITE_WITH_ANDROID
#include <sys/syscall.h>
#include <unistd.h>
#endif
#if __APPLE__
#include "TargetConditionals.h"
#if TARGET_OS_IPHONE
#include <mach/machine.h>
#include <sys/sysctl.h>
#include <sys/types.h>
#endif // TARGET_OS_IPHONE
#endif // __APPLE__
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
typedef enum {
LITE_POWER_HIGH = 0,
LITE_POWER_LOW = 1,
LITE_POWER_FULL = 2,
LITE_POWER_NO_BIND = 3,
LITE_POWER_RAND_HIGH = 4,
LITE_POWER_RAND_LOW = 5
} PowerMode;
typedef enum {
kAPPLE = 0,
kA53 = 53,
kA55 = 55,
kA57 = 57,
kA72 = 72,
kA73 = 73,
kA75 = 75,
kA76 = 76,
kARMArch_UNKOWN = -1
} ARMArch;
class DeviceInfo {
public:
int idx_;
int max_freq_;
int min_freq_;
int generate_arch_;
int compute_core_num_;
int max_memory_;
int sharemem_size_;
std::string device_name_;
std::string compute_ability_;
std::vector<int> L1_cache_;
std::vector<int> L2_cache_;
std::vector<int> L3_cache_;
std::vector<int> core_ids_;
std::vector<int> big_core_ids_;
std::vector<int> little_core_ids_;
std::vector<int> cluster_ids_;
std::vector<ARMArch> archs_;
static DeviceInfo& Global() {
static auto* x = new DeviceInfo;
return *x;
}
static void init_info() {
auto& info = Global();
get_info(&info);
}
private:
DeviceInfo() = default;
static void get_info(DeviceInfo* dev);
};
size_t arm_get_meminfo();
int arm_get_cpucount();
void arm_get_cpu_arch(std::vector<ARMArch>* archs);
bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name);
#ifdef LITE_WITH_ANDROID
void set_default_cache(DeviceInfo* dev);
std::string arm_get_cpu_name();
int get_max_freq_khz(int cpuid);
int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector<int>* cpuids,
const std::vector<int>& cpu_freq,
std::vector<int>* cluster_ids);
int check_online(const std::vector<int>& core_ids);
int set_sched_affinity(const std::vector<int>& cpuids);
#endif // LITE_WITH_ANDROID
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
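For completeness, a small sketch of querying the probed topology through DeviceInfo (field names as declared above; the printed values are device-dependent):

// Hypothetical example, assuming LITE_WITH_ARM is defined.
#include "paddle/fluid/lite/core/cpu_info.h"

void device_info_example() {
  paddle::lite::DeviceInfo::init_info();  // fills the singleton via get_info()
  auto& dev = paddle::lite::DeviceInfo::Global();
  LOG(INFO) << "cores: " << dev.compute_core_num_
            << ", big: " << dev.big_core_ids_.size()
            << ", little: " << dev.little_core_ids_.size()
            << ", L2 of core 0: " << dev.L2_cache_[0] / 1024 << " KB";
}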
@@ -44,7 +44,7 @@ class KernelBase {
   virtual void Run() = 0;

   void SetContext(std::unique_ptr<KernelContext>&& ctx) {
-    context_ = std::move(ctx);
+    ctx_ = std::move(ctx);
   }
   template <typename T>
   void SetParam(T param) {
@@ -86,7 +86,7 @@ class KernelBase {
   virtual TargetType target() const = 0;
   virtual PrecisionType precision() const = 0;
   virtual DataLayoutType layout() const = 0;
-  const KernelContext* context() const { return context_.get(); }
+  const KernelContext* context() const { return ctx_.get(); }
   virtual std::string name() const = 0;

   // Short human-readable document.
@@ -134,7 +134,7 @@ class KernelBase {
   void Torch() {}

  protected:
-  std::unique_ptr<KernelContext> context_;
+  std::unique_ptr<KernelContext> ctx_;
   mutable operators::param_t param_;
   // The corresponding op type.
   std::string op_type_{};
@@ -152,9 +152,6 @@ template <TargetType Target, PrecisionType Precision,
           DataLayoutType DataLayout = DataLayoutType::kNCHW>
 class KernelLite : public KernelBase {
  public:
-  // Set runtime context.
-  void SetContext(std::unique_ptr<KernelContext>&& ctx) { ctx_ = ctx; }
-
   // Run the kernel.
   virtual void Run() { CHECK(false) << "Not Implemented"; }
@@ -168,9 +165,6 @@ class KernelLite : public KernelBase {
   KernelLite() = default;
   virtual ~KernelLite() = default;
-
- protected:
-  std::unique_ptr<KernelContext> ctx_;
 };

 template <TargetType Target, PrecisionType Precision, DataLayoutType DataLayout>
......
@@ -14,6 +14,7 @@
 #pragma once
 #include <algorithm>
+#include <functional>  // for multiplies
 #include <memory>
 #include <numeric>
 #include <vector>
@@ -40,6 +41,10 @@ class DDimLite : public DDimBase<DDimLite> {
   size_t size() const { return data_.size(); }
   bool empty() const { return data_.empty(); }
+  value_type product() const {
+    return std::accumulate(std::begin(data_), std::end(data_), 1,
+                           std::multiplies<value_type>());
+  }

   const std::vector<value_type> &data() const { return data_; }

  private:
@@ -61,8 +66,10 @@ class TensorLite : public TensorBase<TensorLite> {
   }

   void Resize(const DDimLite &ddim) { dims_ = ddim; }
+  void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }

   const DDimLite &dims() const { return dims_; }
+  int64_t numel() const { return dims_.product(); }

   const LoD &lod() const { return lod_; }
   LoD *mutable_lod() { return &lod_; }
......
@@ -32,7 +32,6 @@ class RuntimeContextAssignPass : public StmtPass {
       if (!node.IsStmt()) continue;
       auto& inst = node.AsStmt();
       switch (inst.picked_kernel().target()) {
         case TARGET(kHost):
         case TARGET(kX86):
@@ -42,6 +41,11 @@ class RuntimeContextAssignPass : public StmtPass {
         case TARGET(kCUDA):
           inst.picked_kernel().SetContext(NewCudaContext());
           break;
+#endif
+#ifdef LITE_WITH_ARM
+        case TARGET(kARM):
+          inst.picked_kernel().SetContext(NewARMContext());
+          break;
 #endif
         default:
           LOG(FATAL) << "unsupported target "
@@ -54,9 +58,18 @@ class RuntimeContextAssignPass : public StmtPass {
     std::unique_ptr<KernelContext> ctx(new KernelContext);
     ctx->As<HostContext>();
     // Some initialization here.
     return ctx;
   }

+#ifdef LITE_WITH_ARM
+  std::unique_ptr<KernelContext> NewARMContext() {
+    DeviceInfo::init_info();
+    std::unique_ptr<KernelContext> ctx(new KernelContext);
+    ctx->As<ARMContext>();
+    return ctx;
+  }
+#endif
 #ifdef LITE_WITH_CUDA
   std::unique_ptr<KernelContext> NewCudaContext() {
     std::unique_ptr<KernelContext> ctx(new KernelContext);
@@ -66,9 +79,7 @@ class RuntimeContextAssignPass : public StmtPass {
     cuda.blas_fp32 = cublas_fp32_;
     return ctx;
   }
-#endif

-#ifdef LITE_WITH_CUDA
   void InitCudaBlas() {
     cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
   }
......
message(STATUS "add lite kernels") message(STATUS "add lite kernels")
set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_lite}) set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite})
add_subdirectory(host) add_subdirectory(host)
add_subdirectory(arm) add_subdirectory(arm)
add_subdirectory(cuda) add_subdirectory(cuda)
......
@@ -4,11 +4,13 @@ endif()

 message(STATUS "compile with lite ARM kernels")

-cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} eigen3)
+cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
 cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
 cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)

+lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm eigen3)
+
 set(arm_kernels
     fc_compute_arm
     relu_compute_arm
......
@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/lite/kernels/arm/fc_compute.h"
-#include <Eigen/Core>
+#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 #include "paddle/fluid/lite/core/type_system.h"
@@ -22,24 +22,42 @@ namespace lite {
 namespace kernels {
 namespace arm {

-// NOTE should use pure std C++ implementation.
 void FcCompute::Run() {
   auto& param = this->Param<operators::FcParam>();
+  auto x_dims = param.input->dims();
+  auto w_dims = param.w->dims();

-  CHECK_GE(param.input->dims().size(), 2UL);
+  CHECK_GE(x_dims.size(), 2UL);
+  CHECK_EQ(w_dims.size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);

-  fc_compute_eigen(
-      param.input->data<float>(),  // x
-      param.input->dims().Slice(0, param.in_num_col_dims).production(),
-      param.input->dims()
-          .Slice(param.in_num_col_dims, param.input->dims().size())
-          .production(),
-      param.w->data<float>(),  // w
-      param.w->dims()[1],      // w_w
-      param.w->dims()[0],      // w_h
-      param.bias->data<float>(),  // b
-      param.output->mutable_data<float>());
+  const auto* i_data = param.input->data<float>();
+  const auto* w_data = param.w->data<float>();
+  const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
+  auto* o_data = param.output->mutable_data<float>();
+
+  int x_h = x_dims.Slice(0, param.in_num_col_dims).production();
+  int x_w = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
+  int n = w_dims[1];
+  CHECK_EQ(x_w, static_cast<int>(w_dims[0]));
+
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  if (x_h > 1) {
+    float* packed_in = static_cast<float*>(ctx.workspace_data<float>()) +
+                       ctx.l2_cache_size() / sizeof(float);
+    lite::arm::math::prepackA(packed_in, i_data, x_w, 0, x_h, 0, x_w, false,
+                              &ctx);
+    lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n,
+                                   x_w, false, false, false, &ctx);
+    if (param.bias) {
+      CHECK_EQ(param.bias->numel(), n);
+      lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n);
+    }
+  } else {
+    // use sgemv
+    // sgemv((const float*)weights, (const float*)din, (float*)dout,
+    //       false, n, x_w, _param->_flag_bias, (float*)bias, false);
+  }
 }

 TargetType FcCompute::target() const { return TARGET(kARM); }
......
@@ -13,7 +13,6 @@
 // limitations under the License.

 #pragma once
-#include <Eigen/Core>
 #include "paddle/fluid/lite/core/kernel.h"
 #include "paddle/fluid/lite/operators/fc_op.h"
@@ -34,52 +33,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
   virtual ~FcCompute() = default;
 };

-template <typename T>
-void fc_compute_eigen(const T* x, int x_w, int x_h,  //
-                      const T* w, int w_w, int w_h,  //
-                      const T* b,                    //
-                      T* out) {
-  using matrix_t =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
-
-  Eigen::Map<const matrix_t> X(x, x_h, x_w);
-  Eigen::Map<const matrix_t> W(w, w_h, w_w);
-  Eigen::Map<matrix_t> Out(out, x_h, w_h);
-
-  Out = X * W.transpose();
-
-  if (b) {
-    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_h);
-    Out = Out.array().rowwise() + B.transpose().array();
-  }
-}
-
-template <typename T>
-__attribute__((optimize("unroll-loops")))  //
-T dot(const T* x, const T* y, int dim) {
-  T out{};
-  for (int i = 0; i < dim; i++) {
-    out += x[i] * y[i];
-  }
-  return out;
-}
-
-template <typename T>
-void fc_compute_naive(const T* x, int x_w, int x_h,  //
-                      const T* w, int w_w, int w_h,  //
-                      const T* b,                    //
-                      T* out) {
-  CHECK_EQ(x_w, w_w);
-  // out shape: (x_h, w_w)
-  memset(out, 0, x_h * w_h * sizeof(T));
-
-  for (int r = 0; r < x_h; r++) {
-    for (int c = 0; c < w_h; c++) {
-      out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c];
-    }
-  }
-}
-
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/lite/kernels/arm/fc_compute.h"
 #include <gtest/gtest.h>
 #include <vector>
+#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"

 namespace paddle {
@@ -22,60 +23,79 @@ namespace lite {
 namespace kernels {
 namespace arm {

-TEST(fc_compute_naive, test) {
-  lite::Tensor x, w, b, out, out1;
-  const int batch_size = 2;
+TEST(fc_arm, retrive_op) {
+  auto fc =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
+  ASSERT_FALSE(fc.empty());
+  ASSERT_TRUE(fc.front());
+}
+
+TEST(fc_arm, init) {
+  FcCompute fc;
+  ASSERT_EQ(fc.precision(), PRECISION(kFloat));
+  ASSERT_EQ(fc.target(), TARGET(kARM));
+}
+
+TEST(fc_arm, compare_test) {
+  lite::Tensor x, w, b, out, ref;
+  constexpr int batch_size = 2;
   x.Resize({batch_size, 3});
-  w.Resize({4, 3});
+  w.Resize({3, 4});
   b.Resize({1, 4});
   out.Resize({batch_size, 4});
-  out1.Resize({batch_size, 4});
+  ref.Resize({batch_size, 4});

   auto x_data = x.mutable_data<float>();
   auto w_data = w.mutable_data<float>();
   auto b_data = b.mutable_data<float>();
   auto out_data = out.mutable_data<float>();
-  auto out_data1 = out1.mutable_data<float>();
+  auto ref_data = ref.mutable_data<float>();

-  for (int i = 0; i < product(x.dims()); i++) x_data[i] = i;
-  for (int i = 0; i < product(w.dims()); i++) w_data[i] = i;
-  for (int i = 0; i < product(b.dims()); i++) b_data[i] = i;
+  for (int64_t i = 0; i < x.dims().product(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  for (int64_t i = 0; i < w.dims().product(); i++) {
+    w_data[i] = static_cast<float>(i);
+  }
+  for (int64_t i = 0; i < b.dims().product(); i++) {
+    b_data[i] = static_cast<float>(i);
+  }

-  fc_compute_naive(x_data, 3, batch_size,  //
-                   w_data, 3, 4,           //
-                   b_data, out_data);
-  fc_compute_eigen(x_data, 3, batch_size,  //
-                   w_data, 3, 4,           //
-                   b_data, out_data1);
-
-  for (int i = 0; i < product(out.dims()); i++) {
-    EXPECT_NEAR(out_data[0], out_data1[0], 1e-6);
-  }
-}
+  // TODO(TJ): enable bias soon
+  b_data = nullptr;

-TEST(fc_arm, init) {
+  lite::arm::math::fc_compute_eigen(x_data, batch_size, 3,  //
+                                    w_data, 3, 4,           //
+                                    b_data, ref_data);
+
+  // fc compute kernel
   FcCompute fc;
-  ASSERT_EQ(fc.precision(), PRECISION(kFloat));
-  ASSERT_EQ(fc.target(), TARGET(kARM));
-}
+  operators::FcParam param;

-TEST(fc_arm, algorithm) {
-  using matrix_t = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
-  using matrix_map_t = Eigen::Map<matrix_t>;
+  param.in_num_col_dims = 1;
+  param.input = &x;
+  param.w = &w;
+  param.bias = nullptr;
+  param.output = &out;
+  param.in_mat_dims = x.dims();

-  // dim 10, 20
-  std::vector<float> input(10 * 20);
-  std::vector<float> w(20 * 20);
-  std::vector<float> output(10 * 20);
+  DeviceInfo::init_info();
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<ARMContext>();
+  fc.SetParam(param);
+  fc.SetContext(std::move(ctx));
+  fc.Run();

-  Eigen::Map<const matrix_t> input_mat(input.data(), 10, 20);
-  Eigen::Map<const matrix_t> weight_mat(w.data(), 20, 20);
-  matrix_map_t output_mat(output.data(), 10, 20);
+  VLOG(3) << "output vs ref";
+  for (int i = 0; i < out.dims().product(); i++) {
+    VLOG(3) << out_data[i] << " vs " << ref_data[i];
+  }

-  output_mat = weight_mat.transpose() * input_mat;
+  for (int i = 0; i < out.dims().product(); ++i) {
+    EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
+  }
 }

-TEST(fc_arm, compute) {
+TEST(fc_arm, num_col_dims) {
   FcCompute fc;
   operators::FcParam param;
@@ -84,20 +104,28 @@ TEST(fc_arm, compute) {
   lite::Tensor bias;
   lite::Tensor output;

-  x.Resize(DDim(std::vector<int64_t>({1, 10, 20})));
-  w.Resize(DDim(std::vector<int64_t>({20, 20})));
-  bias.Resize(DDim(std::vector<int64_t>({1, 10})));
-  output.Resize(DDim(std::vector<int64_t>({10, 20})));
+  x.Resize({1, 2, 3});
+  w.Resize({3, 4});
+  bias.Resize({1, 4});
+  output.Resize({2, 4});

   auto* x_data = x.mutable_data<float>();
   auto* w_data = w.mutable_data<float>();
   auto* bias_data = bias.mutable_data<float>();
   auto* output_data = output.mutable_data<float>();

-  for (int i = 0; i < 10 * 20; i++) x_data[i] = i;
-  for (int i = 0; i < 20 * 20; i++) w_data[i] = i;
-  for (int i = 0; i < 10; i++) bias_data[i] = i;
-  for (int i = 0; i < 10 * 20; i++) output_data[i] = 0;
+  for (int64_t i = 0; i < x.dims().product(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  for (int64_t i = 0; i < w.dims().product(); i++) {
+    w_data[i] = static_cast<float>(i);
+  }
+  for (int64_t i = 0; i < bias.dims().product(); i++) {
+    bias_data[i] = static_cast<float>(i);
+  }
+  for (int64_t i = 0; i < output.dims().product(); i++) {
+    output_data[i] = static_cast<float>(i);
+  }

   param.in_num_col_dims = 2;
   param.input = &x;
@@ -106,20 +134,13 @@ TEST(fc_arm, num_col_dims) {
   param.output = &output;
   param.in_mat_dims = x.dims();

+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<ARMContext>();
+  DeviceInfo::init_info();
+
   fc.SetParam(param);
+  fc.SetContext(std::move(ctx));
   fc.Run();
-
-  LOG(INFO) << "x";
-  for (int i = 0; i < 10 * 20; i++) LOG(INFO) << x_data[i];
-
-  LOG(INFO) << "output:";
-  for (int i = 0; i < 10 * 20; i++) LOG(INFO) << output.data<float>()[i];
-}
-
-TEST(fc, retrive_op) {
-  auto fc =
-      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
-  ASSERT_TRUE(fc);
 }

 }  // namespace arm
......
@@ -35,8 +35,8 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
   using param_t = operators::MulParam;

   void Run() override {
-    CHECK(context_) << "running context should be set first";
-    auto& context = context_->As<CUDAContext>();
+    CHECK(ctx_) << "running context should be set first";
+    auto& context = ctx_->As<CUDAContext>();
     CHECK(context.blas_fp32) << "blas should init first";
     /*
     auto& blas = *context.blas_fp32;
......
@@ -60,7 +60,7 @@ class SquareCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::ActivationParam;

   void Run() override {
-    auto& context = context_->As<X86Context>();
+    auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<operators::ActivationParam>();
     CHECK(context.x86_device_context);
@@ -82,7 +82,7 @@ class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::ActivationGradParam;

   void Run() override {
-    auto& context = context_->As<X86Context>();
+    auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<operators::ActivationGradParam>();
     CHECK(context.x86_device_context);
     param.X_grad->template mutable_data<T>();
......
@@ -38,7 +38,7 @@ class ElementwiseSubCompute

   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
-    auto& context = context_->As<X86Context>();
+    auto& context = ctx_->As<X86Context>();
     CHECK(context.x86_device_context);
     param.Out->template mutable_data<T>();
......
@@ -22,7 +22,9 @@ function cmake_arm {
         -DLITE_WITH_CUDA=OFF \
         -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
         -DWITH_TESTING=ON \
+        -DWITH_MKL=OFF \
         -DWITH_MKLDNN=OFF
+    make cxx_api_lite_bin -j8
 }

 function build {
......