Unverified commit 19f8ac5c, authored by TianXiaogang, committed by GitHub

fix: fix deviceinfo workspace to tls

move DeviceInfo workspace and other related members to thread-local storage
Parent ab508fcc
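The motivation, in a minimal sketch (plain C++, assuming nothing about PaddleLite internals beyond the commit message): when worker threads share a single workspace buffer, concurrent ExtendWorkspace calls race on the resize; a thread_local buffer gives every thread its own copy, so no lock is needed.

```cpp
#include <cstddef>
#include <thread>
#include <vector>

// One workspace buffer per thread: resizing it from one thread cannot
// invalidate pointers another thread is using.
thread_local std::vector<char> workspace;

void ExtendWorkspace(std::size_t bytes) {
  if (bytes > workspace.size()) workspace.resize(bytes);
}

int main() {
  std::thread a([] { ExtendWorkspace(1 << 20); });  // each thread grows
  std::thread b([] { ExtendWorkspace(1 << 10); });  // its own buffer only
  a.join();
  b.join();
}
```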
@@ -59,6 +59,12 @@ namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
thread_local lite_api::PowerMode DeviceInfo::mode_;
thread_local ARMArch DeviceInfo::arch_;
thread_local int DeviceInfo::mem_size_;
thread_local std::vector<int> DeviceInfo::active_ids_;
thread_local TensorLite DeviceInfo::workspace_;
thread_local int64_t DeviceInfo::count_ = 0;
#ifdef TARGET_IOS
const int DEFAULT_L1_CACHE_SIZE = 64 * 1024;
......
@@ -79,7 +79,6 @@ class DeviceInfo {
int core_num_;
std::vector<int> max_freqs_;
std::vector<int> min_freqs_;
int mem_size_;
std::string dev_name_;
std::vector<int> L1_cache_;
@@ -94,14 +93,15 @@ class DeviceInfo {
std::vector<bool> fp16_;
std::vector<bool> dot_;
ARMArch arch_;
// LITE_POWER_HIGH stands for using big cores,
// LITE_POWER_LOW stands for using small core,
// LITE_POWER_FULL stands for using all cores
lite_api::PowerMode mode_;
std::vector<int> active_ids_;
TensorLite workspace_;
int64_t count_{0};
static thread_local lite_api::PowerMode mode_;
static thread_local ARMArch arch_;
static thread_local int mem_size_;
static thread_local std::vector<int> active_ids_;
static thread_local TensorLite workspace_;
static thread_local int64_t count_;
void SetDotInfo(int argc, ...);
void SetFP16Info(int argc, ...);
@@ -119,7 +119,6 @@ class DeviceInfo {
DeviceInfo() = default;
};
#endif // LITE_WITH_ARM
template <TargetType Type>
......
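The hunks above follow the standard C++ split for static thread_local members: the declaration lives in the class body, and the out-of-class definition in the .cc file repeats thread_local (but not static). A compilable sketch with illustrative names, not the real DeviceInfo members:

```cpp
#include <vector>

// device_info.h (declaration inside the class)
class Device {
  static thread_local int mem_size_;
  static thread_local std::vector<int> active_ids_;
};

// device_info.cc (definition; thread_local is repeated, static is omitted)
thread_local int Device::mem_size_ = 0;
thread_local std::vector<int> Device::active_ids_;
```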
@@ -20,15 +20,10 @@ namespace kernels {
namespace arm {
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
auto& param = this->template Param<param_t>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
if (last_shape_ == x_dims) {
return;
}
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
// extend workspace
if (param.strides[0] == 2) {
ctx.ExtendWorkspace(
lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx));
@@ -36,12 +31,7 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
ctx.ExtendWorkspace(
lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx));
}
}
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<float>();
const auto* w_data = weights_.data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
@@ -89,9 +79,6 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
@@ -148,9 +135,6 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
}
}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::ReInitWhenNeeded() {}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
......
@@ -156,7 +156,6 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
last_shape_ = x_dims;
int ic = x_dims[1];
int oc = o_dims[1];
@@ -179,12 +178,10 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
w_scale_);
}
virtual void ReInitWhenNeeded();
virtual void Run();
/// todo, support inplace weights transform
protected:
DDim last_shape_;
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
......
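A plausible reading of the DirectConv change: once the workspace is thread_local, the old shape-change early-return in ReInitWhenNeeded() would skip the extension on a second thread whose own (still empty) workspace has never been grown. Dropping last_shape_ and moving ExtendWorkspace into Run() makes the extension unconditional on the executing thread. A minimal sketch with illustrative names (Ctx and DirectConvSketch are not PaddleLite types):

```cpp
#include <cstddef>
#include <vector>

struct Ctx {
  static thread_local std::vector<char> workspace_;
  void ExtendWorkspace(std::size_t bytes) {
    if (bytes > workspace_.size()) workspace_.resize(bytes);
  }
  template <typename T>
  T* workspace_data() {
    return reinterpret_cast<T*>(workspace_.data());
  }
};
thread_local std::vector<char> Ctx::workspace_;

struct DirectConvSketch {
  void Run(Ctx& ctx, std::size_t needed_bytes) {
    ctx.ExtendWorkspace(needed_bytes);         // unconditional, on the
    float* tmp = ctx.workspace_data<float>();  // Run() thread's own buffer
    (void)tmp;  // ... direct convolution would use tmp here ...
  }
};
```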
@@ -85,6 +85,7 @@ template <>
void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
auto weights = param.filter->data<float>();
if (flag_trans_weights_) {
weights = weights_.data<float>();
@@ -120,6 +121,7 @@ template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
auto weights = param.filter->data<int8_t>();
if (flag_trans_weights_) {
weights = weights_.data<int8_t>();
@@ -179,6 +181,7 @@ template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
auto weights = param.filter->data<int8_t>();
if (flag_trans_weights_) {
weights = weights_.data<int8_t>();
......
@@ -72,7 +72,7 @@ class GemmLikeConv : public KernelLite<TARGET(kARM), Ptype> {
} else {
//! im2col gemmlike conv
flag_1x1gemm_ = false;
ctx.ExtendWorkspace(k * n * sizeof(float));
workspace_size_ = k * n * sizeof(float);
}
if (!flag_trans_weights_ && n > 1) {
lite::arm::math::trans_gemm_weights<Ptype>(
@@ -97,6 +97,7 @@ class GemmLikeConv : public KernelLite<TARGET(kARM), Ptype> {
bool flag_trans_bias_{false};
Tensor weights_;
Tensor bias_;
int workspace_size_{0};
};
} // namespace arm
......
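The GemmLikeConv hunks distill to one recipe, which also recurs in the Conv2DTranspose and WinogradConv hunks below: compute the byte count during preparation, but defer the actual ExtendWorkspace call to Run(), which executes on the thread that owns the thread-local buffer. A sketch with stand-in names (Ctx and GemmLikeSketch are not the real lite classes):

```cpp
#include <cstddef>
#include <vector>

struct Ctx {
  static thread_local std::vector<char> workspace_;
  void ExtendWorkspace(std::size_t bytes) {
    if (bytes > workspace_.size()) workspace_.resize(bytes);
  }
};
thread_local std::vector<char> Ctx::workspace_;

class GemmLikeSketch {
 public:
  void PrepareForRun(int k, int n) {
    // Size only; no allocation yet (may run on a different thread).
    workspace_size_ = static_cast<std::size_t>(k) * n * sizeof(float);
  }
  void Run(Ctx& ctx) {
    ctx.ExtendWorkspace(workspace_size_);  // allocate per thread, per run
    // ... im2col + gemm using the per-thread workspace ...
  }

 private:
  std::size_t workspace_size_{0};
};
```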
@@ -40,13 +40,13 @@ void Conv2DTransposeCompute::PrepareForRun() {
int group = param.groups;
// deconv weights layout: chin * chout * kh * kw
auto& ctx = this->ctx_->template As<ARMContext>();
int m = chout * kw * kh / group;
int n = hin * win;
int k = chin / group;
ctx.ExtendWorkspace(group * m * n * sizeof(float));
workspace_size_ = group * m * n * sizeof(float);
auto& ctx = this->ctx_->template As<ARMContext>();
lite::Tensor tmp_weights;
lite::arm::math::prepackA(
&tmp_weights, *(param.filter), 1.f, m, k, group, true, &ctx);
@@ -57,6 +57,8 @@ void Conv2DTransposeCompute::PrepareForRun() {
}
void Conv2DTransposeCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto o_dims = param.output->dims();
@@ -80,7 +82,6 @@ void Conv2DTransposeCompute::Run() {
int group_size_in = win * hin * chin / group;
int group_size_out = wout * hout * chout / group;
int group_size_coldata = m * n;
auto& ctx = this->ctx_->template As<ARMContext>();
int hblock = lite::arm::math::get_hblock(&ctx);
int m_roundup = hblock * ((m + hblock - 1) / hblock);
int group_size_weights = ((m_roundup * k + 15) / 16) * 16;
......
@@ -32,6 +32,9 @@ class Conv2DTransposeCompute
void Run() override;
~Conv2DTransposeCompute() = default;
protected:
int workspace_size_{0};
};
} // namespace arm
......
@@ -46,8 +46,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
int max_ch = ic > oc ? ic : oc;
const int n_wino = size_tile;
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
last_shape_ = x_dims;
}
@@ -76,8 +75,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
@@ -106,6 +104,9 @@ template <>
void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
// extend workspace
ctx.ExtendWorkspace(workspace_size_);
const auto* i_data = param.x->data<float>();
const auto* w_data = weights_.data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
......
@@ -39,6 +39,7 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
using param_t = operators::ConvParam;
Tensor weights_;
DDim last_shape_;
int workspace_size_{0};
};
} // namespace arm
......
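WinogradConv keeps a ReInitWhenNeeded() step because its workspace size depends on the input shape: workspace_size_ is recomputed only when the shape changes, while Run() still calls ctx.ExtendWorkspace(workspace_size_) unconditionally, as in the GemmLike sketch above. A shape-guard sketch with an illustrative (not the real Winograd) size formula:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

class WinogradSketch {
 public:
  void ReInitWhenNeeded(const std::vector<std::int64_t>& x_dims) {
    if (x_dims == last_shape_) return;      // shape unchanged: size is valid
    workspace_size_ = ComputeSize(x_dims);  // shape-dependent byte count
    last_shape_ = x_dims;
  }
  std::size_t workspace_size() const { return workspace_size_; }

 private:
  static std::size_t ComputeSize(const std::vector<std::int64_t>& dims) {
    std::size_t n = 1;
    for (auto d : dims) n *= static_cast<std::size_t>(d);
    return n * sizeof(float);  // placeholder formula, not the real one
  }
  std::vector<std::int64_t> last_shape_;
  std::size_t workspace_size_{0};
};
```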