diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc
index 896f6c8d33a8665c4c94786dd08af1a097942608..166c04c000d345eb39822d1d67321a1c6a05e9a5 100644
--- a/lite/core/device_info.cc
+++ b/lite/core/device_info.cc
@@ -59,6 +59,12 @@ namespace paddle {
 namespace lite {
 
 #ifdef LITE_WITH_ARM
+thread_local lite_api::PowerMode DeviceInfo::mode_;
+thread_local ARMArch DeviceInfo::arch_;
+thread_local int DeviceInfo::mem_size_;
+thread_local std::vector<int> DeviceInfo::active_ids_;
+thread_local TensorLite DeviceInfo::workspace_;
+thread_local int64_t DeviceInfo::count_ = 0;
 
 #ifdef TARGET_IOS
 const int DEFAULT_L1_CACHE_SIZE = 64 * 1024;
diff --git a/lite/core/device_info.h b/lite/core/device_info.h
index 81c0ac4bf9a9a134de448efa92ac0cb2c1a06454..1ff8b896a70dc538d2486a24db2625c7b62c600a 100644
--- a/lite/core/device_info.h
+++ b/lite/core/device_info.h
@@ -79,7 +79,6 @@ class DeviceInfo {
   int core_num_;
   std::vector<int> max_freqs_;
   std::vector<int> min_freqs_;
-  int mem_size_;
   std::string dev_name_;
 
   std::vector<int> L1_cache_;
@@ -94,14 +93,15 @@ class DeviceInfo {
   std::vector<bool> fp16_;
   std::vector<bool> dot_;
 
-  ARMArch arch_;
   // LITE_POWER_HIGH stands for using big cores,
   // LITE_POWER_LOW stands for using small core,
   // LITE_POWER_FULL stands for using all cores
-  lite_api::PowerMode mode_;
-  std::vector<int> active_ids_;
-  TensorLite workspace_;
-  int64_t count_{0};
+  static thread_local lite_api::PowerMode mode_;
+  static thread_local ARMArch arch_;
+  static thread_local int mem_size_;
+  static thread_local std::vector<int> active_ids_;
+  static thread_local TensorLite workspace_;
+  static thread_local int64_t count_;
 
   void SetDotInfo(int argc, ...);
   void SetFP16Info(int argc, ...);
@@ -119,7 +119,6 @@ class DeviceInfo {
   DeviceInfo() = default;
 };
-
 #endif  // LITE_WITH_ARM
 
 template <TargetType Type>
diff --git a/lite/kernels/arm/conv_direct.cc b/lite/kernels/arm/conv_direct.cc
index ae8c1d1b9aa4e1e3e79c68116d91a0d0c1e9b1ab..ccf36391e7b252f3d04b83e538ef51f0e45ca67e 100644
--- a/lite/kernels/arm/conv_direct.cc
+++ b/lite/kernels/arm/conv_direct.cc
@@ -20,15 +20,10 @@ namespace kernels {
 namespace arm {
 
 template <>
-void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
-  auto& param = this->template Param<param_t>();
-  auto x_dims = param.x->dims();
-  auto w_dims = param.filter->dims();
-  auto o_dims = param.output->dims();
-  if (last_shape_ == x_dims) {
-    return;
-  }
+void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
+  auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
+  // extend workspace
   if (param.strides[0] == 2) {
     ctx.ExtendWorkspace(
         lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx));
@@ -36,12 +31,7 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
     ctx.ExtendWorkspace(
         lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx));
   }
-}
 
-template <>
-void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
-  auto& param = this->Param<param_t>();
-  auto& ctx = this->ctx_->template As<ARMContext>();
   const auto* i_data = param.x->data<float>();
   const auto* w_data = weights_.data<float>();
   const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
@@ -89,9 +79,6 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   }
 }
 
-template <>
-void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {}
-
 template <>
 void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<param_t>();
@@ -148,9 +135,6 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   }
 }
 
-template <>
-void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::ReInitWhenNeeded() {}
-
 template <>
 void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
   auto& param = this->Param<param_t>();
diff --git a/lite/kernels/arm/conv_direct.h b/lite/kernels/arm/conv_direct.h
index 24c934e14b5ae5e7de4d089da1611deb0e77fefb..cd90c4d6c5adb0d33fbd8082db02cecc9f76d9fb 100644
--- a/lite/kernels/arm/conv_direct.h
+++ b/lite/kernels/arm/conv_direct.h
@@ -156,7 +156,6 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
     auto x_dims = param.x->dims();
     auto w_dims = param.filter->dims();
     auto o_dims = param.output->dims();
-    last_shape_ = x_dims;
 
     int ic = x_dims[1];
     int oc = o_dims[1];
@@ -179,12 +178,10 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
         w_scale_);
   }
 
-  virtual void ReInitWhenNeeded();
   virtual void Run();
 
   /// todo, support inplace weights transform
  protected:
-  DDim last_shape_;
   Tensor weights_;
   Tensor bias_;
   bool flag_trans_weights_{false};
diff --git a/lite/kernels/arm/conv_gemmlike.cc b/lite/kernels/arm/conv_gemmlike.cc
index 56dc72f2d6bbc331ccc14305d502f11cf4f27609..4b1f57886955bc0fa006d708d04a191c0df768e3 100644
--- a/lite/kernels/arm/conv_gemmlike.cc
+++ b/lite/kernels/arm/conv_gemmlike.cc
@@ -85,6 +85,7 @@ template <>
 void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
+  ctx.ExtendWorkspace(workspace_size_);
   auto weights = param.filter->data<float>();
   if (flag_trans_weights_) {
     weights = weights_.data<float>();
@@ -120,6 +121,7 @@ template <>
 void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
+  ctx.ExtendWorkspace(workspace_size_);
   auto weights = param.filter->data<int8_t>();
   if (flag_trans_weights_) {
     weights = weights_.data<int8_t>();
@@ -179,6 +181,7 @@ template <>
 void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
+  ctx.ExtendWorkspace(workspace_size_);
   auto weights = param.filter->data<int8_t>();
   if (flag_trans_weights_) {
     weights = weights_.data<int8_t>();
diff --git a/lite/kernels/arm/conv_gemmlike.h b/lite/kernels/arm/conv_gemmlike.h
index 0f1213390b1febcc721cefc8a4005184dd00d3ec..e00b8de6f4a66dfea91e8806821ba7cf3a9aa62b 100644
--- a/lite/kernels/arm/conv_gemmlike.h
+++ b/lite/kernels/arm/conv_gemmlike.h
@@ -72,7 +72,7 @@ class GemmLikeConv : public KernelLite<TARGET(kARM), Ptype> {
     } else {
       //! im2col gemmlike conv
       flag_1x1gemm_ = false;
-      ctx.ExtendWorkspace(k * n * sizeof(float));
+      workspace_size_ = k * n * sizeof(float);
     }
     if (!flag_trans_weights_ && n > 1) {
       lite::arm::math::trans_gemm_weights<Ptype>(
@@ -97,6 +97,7 @@ class GemmLikeConv : public KernelLite<TARGET(kARM), Ptype> {
   bool flag_trans_bias_{false};
   Tensor weights_;
   Tensor bias_;
+  int workspace_size_{0};
 };
 
 }  // namespace arm
diff --git a/lite/kernels/arm/conv_transpose_compute.cc b/lite/kernels/arm/conv_transpose_compute.cc
index 9fca00ad6b01b0c420d9a0d3ad0f712604a4a441..5a18499c85d682e0983493869e7d54de81641a99 100644
--- a/lite/kernels/arm/conv_transpose_compute.cc
+++ b/lite/kernels/arm/conv_transpose_compute.cc
@@ -40,13 +40,13 @@ void Conv2DTransposeCompute::PrepareForRun() {
   int group = param.groups;
 
   // deconv weights layout: chin * chout * kh * kw
-  auto& ctx = this->ctx_->template As<ARMContext>();
   int m = chout * kw * kh / group;
   int n = hin * win;
   int k = chin / group;
 
-  ctx.ExtendWorkspace(group * m * n * sizeof(float));
+  workspace_size_ = group * m * n * sizeof(float);
 
+  auto& ctx = this->ctx_->template As<ARMContext>();
   lite::Tensor tmp_weights;
   lite::arm::math::prepackA(
       &tmp_weights, *(param.filter), 1.f, m, k, group, true, &ctx);
@@ -57,6 +57,8 @@ void Conv2DTransposeCompute::PrepareForRun() {
 }
 
 void Conv2DTransposeCompute::Run() {
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  ctx.ExtendWorkspace(workspace_size_);
   auto& param = this->Param<param_t>();
   auto x_dims = param.x->dims();
   auto o_dims = param.output->dims();
@@ -80,7 +82,6 @@ void Conv2DTransposeCompute::Run() {
   int group_size_in = win * hin * chin / group;
   int group_size_out = wout * hout * chout / group;
   int group_size_coldata = m * n;
-  auto& ctx = this->ctx_->template As<ARMContext>();
   int hblock = lite::arm::math::get_hblock(&ctx);
   int m_roundup = hblock * ((m + hblock - 1) / hblock);
   int group_size_weights = ((m_roundup * k + 15) / 16) * 16;
diff --git a/lite/kernels/arm/conv_transpose_compute.h b/lite/kernels/arm/conv_transpose_compute.h
index 8b87c7cfad9276c867d21db29c2af14e1e3ea86d..7b781cdd5253205c4eb21b1ddcfa5187110581b5 100644
--- a/lite/kernels/arm/conv_transpose_compute.h
+++ b/lite/kernels/arm/conv_transpose_compute.h
@@ -32,6 +32,9 @@ class Conv2DTransposeCompute
   void Run() override;
 
   ~Conv2DTransposeCompute() = default;
+
+ protected:
+  int workspace_size_{0};
 };
 
 }  // namespace arm
diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc
index f6e73a0a59f8dbf9f0549a4732daaa53b89b9666..d1b8d8a48ecd7d564947486ee2938d6b630c41e5 100644
--- a/lite/kernels/arm/conv_winograd.cc
+++ b/lite/kernels/arm/conv_winograd.cc
@@ -46,8 +46,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
   int max_ch = ic > oc ? ic : oc;
   const int n_wino = size_tile;
-  ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                      sizeof(float));
+  workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
 
   last_shape_ = x_dims;
 }
 
@@ -76,8 +75,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
   int hblock = lite::arm::math::get_hblock(&ctx);
   int m_round = hblock * ((m_wino + hblock - 1) / hblock);
   weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
-  ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                      sizeof(float));
+  workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
   auto weights_wino =
       static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
   void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
@@ -106,6 +104,9 @@ template <>
 void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
+  // extend workspace
+  ctx.ExtendWorkspace(workspace_size_);
+
   const auto* i_data = param.x->data<float>();
   const auto* w_data = weights_.data<float>();
   const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
diff --git a/lite/kernels/arm/conv_winograd.h b/lite/kernels/arm/conv_winograd.h
index 8b6de0af5ed359015e15515b559bfaf754d4c3f9..33f0edc017adca477b2e71964efdcaddb0ca3a08 100644
--- a/lite/kernels/arm/conv_winograd.h
+++ b/lite/kernels/arm/conv_winograd.h
@@ -39,6 +39,7 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
   using param_t = operators::ConvParam;
   Tensor weights_;
   DDim last_shape_;
+  int workspace_size_{0};
 };
 
 }  // namespace arm
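
Note on the pattern (reviewer sketch, not part of the patch): DeviceInfo's workspace becomes thread_local, so ExtendWorkspace() must run on the thread that executes the kernel. PrepareForRun()/ReInitWhenNeeded() may run on a different thread than Run(), which is why each kernel now only records the required byte count in a workspace_size_ member and grows the workspace at the top of Run(). Below is a minimal standalone C++ sketch of that pattern; Workspace, FakeConvKernel, and Extend are hypothetical stand-ins, not Paddle-Lite's real classes or API.

// Sketch: per-thread scratch buffer, size recorded at prepare time,
// buffer grown only on the executing thread.
#include <cstddef>
#include <cstdio>
#include <thread>
#include <vector>

class Workspace {
 public:
  // Grow (never shrink) the calling thread's scratch buffer.
  static void* Extend(size_t bytes) {
    if (buffer_.size() < bytes) buffer_.resize(bytes);
    return buffer_.data();
  }

 private:
  // One buffer per thread: no lock, no cross-thread resize races.
  static thread_local std::vector<char> buffer_;
};
thread_local std::vector<char> Workspace::buffer_;

class FakeConvKernel {
 public:
  void PrepareForRun(size_t k, size_t n) {
    // Only record the requirement; extending here would grow the
    // preparing thread's buffer, not the worker's.
    workspace_size_ = k * n * sizeof(float);
  }

  void Run() {
    // Extend on the executing thread, right before use.
    float* scratch = static_cast<float*>(Workspace::Extend(workspace_size_));
    scratch[0] = 1.0f;  // ... im2col / gemm would write here ...
    std::printf("thread scratch of %zu bytes ready\n", workspace_size_);
  }

 private:
  size_t workspace_size_{0};
};

int main() {
  FakeConvKernel kernel;
  kernel.PrepareForRun(64, 128);              // e.g. on the loader thread
  std::thread worker([&] { kernel.Run(); });  // executed on another thread
  worker.join();
  return 0;
}

One buffer per thread removes the data race on a shared workspace when several predictors run concurrently; the cost is that each worker allocates its own scratch, which the grow-only Extend amortizes across runs.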