diff --git a/src/operators/math/gemm/cpu_info.h b/src/framework/context.h
similarity index 51%
rename from src/operators/math/gemm/cpu_info.h
rename to src/framework/context.h
index 54975797c782be1964c562f38fd12edbcd6a2f0e..d38e1e3b5625b9151cc0c8c4ec41ce66080dd545 100644
--- a/src/operators/math/gemm/cpu_info.h
+++ b/src/framework/context.h
@@ -14,42 +14,66 @@ limitations under the License. */

 #pragma once

+#if _OPENMP
+#include <omp.h>
+#endif
+
 #define MOBILE_MAX_CPU_NUM 8

 namespace paddle_mobile {
-namespace operators {
-namespace math {
+namespace framework {

-struct CPUInfo {
+struct CPUContext {
  private:
-  CPUInfo() {
+  CPUContext() : num_cpus(4), num_threads(1) {
     // TODO(hjchen2)
-    num_cpus = 4;
     for (int i = 0; i < num_cpus; ++i) {
-      cpu_frequency[i] = 2400;      // 2400 MHz
-      max_cpu_frequency[i] = 2400;  // 2400 MHz
+      cpu_frequencies[i] = 2400;      // 2400 MHz
+      max_cpu_frequencies[i] = 2400;  // 2400 MHz
     }
     // L1_cache = 32000;  // 32K
     L1_cache = 32 * 1024;
     L2_cache = 2000000;  // 2M
     // L2_cache = 512000;
   }
-  virtual ~CPUInfo() {}

  public:
-  static CPUInfo* Info() {
-    static CPUInfo* ctx = new CPUInfo;
+  void set_num_threads(int threads) {
+#if _OPENMP
+    omp_set_num_threads(threads);
+    if (threads <= omp_get_max_threads()) {
+      num_threads = threads;
+    } else {
+      num_threads = omp_get_max_threads();
+    }
+#endif
+    num_threads = (num_threads > 1) ? num_threads : 1;
+  }
+
+  virtual ~CPUContext() {}
+
+ public:
+  static CPUContext* Context() {
+    static CPUContext* ctx = new CPUContext;
     return ctx;
   }

   int num_cpus;
-  int cpu_frequency[MOBILE_MAX_CPU_NUM];
-  int max_cpu_frequency[MOBILE_MAX_CPU_NUM];
+  int num_threads;
+  int cpu_frequencies[MOBILE_MAX_CPU_NUM];
+  int max_cpu_frequencies[MOBILE_MAX_CPU_NUM];
   int L1_cache;
   int L2_cache;
 };

-}  // namespace math
-}  // namespace operators
+inline void set_global_num_threads(int threads) {
+  CPUContext::Context()->set_num_threads(threads);
+}
+
+inline int get_global_num_threads() {
+  return CPUContext::Context()->num_threads;
+}
+
+}  // namespace framework
 }  // namespace paddle_mobile
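
A minimal sketch of the clamping behaviour introduced in `CPUContext::set_num_threads`. This is illustrative only, not part of the patch, and assumes a build with OpenMP enabled:

```cpp
// Illustrative only: exercises CPUContext::set_num_threads from this patch.
#include <cstdio>

#include "framework/context.h"

int main() {
  using paddle_mobile::framework::CPUContext;
  // Request more threads than OpenMP will grant: the setter falls back to
  // omp_get_max_threads(), and the stored value never drops below 1.
  CPUContext::Context()->set_num_threads(64);
  std::printf("num_threads = %d\n", CPUContext::Context()->num_threads);
  return 0;
}
```

One consequence of the new flow worth noting: `PaddleMobile::SetThreadNum` (changed below in `paddle_mobile.cpp`) now dereferences `executor_`, so it presumably has to be called after the model is loaded rather than before.
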
*/ #include #include "common/enforce.h" #include "common/log.h" +#include "framework/context.h" #include "framework/framework.pb-c.h" #include "framework/lod_tensor.h" #include "framework/operator.h" @@ -37,6 +38,11 @@ namespace framework { #pragma mark - executor +template +void Executor::SetThreadNum(int threads) { + set_global_num_threads(threads); +} + template Executor::Executor(const Program &program, paddle_mobile::PaddleMobileConfigInternal config, @@ -444,6 +450,9 @@ std::shared_ptr Executor::GetOutput( template PMStatus Executor::Predict() { +#if _OPENMP + omp_set_num_threads(get_global_num_threads()); +#endif #ifdef PADDLE_MOBILE_PROFILE std::vector profile(ops_of_block0_.size()); struct timespec ts; @@ -654,14 +663,18 @@ void Executor::InitNoPersistableMemory( output->Resize(input_tensor.dims()); output->mutable_data(); } + template <> void Executor::SetInput(const Tensor &input, const std::string &var_name) { - auto *target_var = program_.scope->FindVar(var_name); - PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist", - var_name.c_str()); + int index = 0; + if (feed_indices_.find(var_name) != feed_indices_.end()) { + index = feed_indices_.find(var_name)->second; + } + auto *feed_var = program_.scope->Var("feed"); + framework::LoDTensor *target_tensor = + &(feed_var->template GetMutable()->at(index)); - auto *target_tensor = target_var->template GetMutable(); DLOG << "config_.load_when_predict " << config_.load_when_predict; DLOG << "target_tensor->IsInitialized() " << target_tensor->IsInitialized(); DLOG << "target_tensor->dims() " << target_tensor->dims(); @@ -772,7 +785,7 @@ void Executor::InitMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); @@ -840,7 +853,7 @@ void Executor::InitCombineMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); diff --git a/src/framework/executor.h b/src/framework/executor.h index 853914c54cb962c570ae2a9751500d3275091499..074bc4179ade271683a5454edf024661732d270d 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -36,6 +36,8 @@ class Executor { paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1, const bool use_optimize = true, const bool lod_mode = false); + void SetThreadNum(int threads); + PMStatus Predict(const std::vector> &inputs); PMStatus Predict( const std::vector> &inputs); diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 6294f6bf467b1c1684d87c51b9a3b04508d56016..ceab4f4aeec4070327cf9fb46b1dc06ce19cd4a5 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -28,9 +28,7 @@ namespace paddle_mobile { template void PaddleMobile::SetThreadNum(int num) { -#ifdef _OPENMP - omp_set_num_threads(num); -#endif + executor_->SetThreadNum(num); } template diff --git a/src/operators/kernel/arm/density_prior_box_kernel.cpp b/src/operators/kernel/arm/density_prior_box_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..54ea9805ae3a25975f9df23fa9888fe4eced3e29 --- /dev/null +++ b/src/operators/kernel/arm/density_prior_box_kernel.cpp @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
diff --git a/src/operators/kernel/arm/density_prior_box_kernel.cpp b/src/operators/kernel/arm/density_prior_box_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..54ea9805ae3a25975f9df23fa9888fe4eced3e29
--- /dev/null
+++ b/src/operators/kernel/arm/density_prior_box_kernel.cpp
@@ -0,0 +1,36 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef DENSITY_PRIORBOX_OP
+
+#include "operators/kernel/prior_box_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool DensityPriorBoxKernel<CPU, float>::Init(DensityPriorBoxParam<float> *param) {
+  return true;
+}
+
+template <>
+void DensityPriorBoxKernel<CPU, float>::Compute(
+    const DensityPriorBoxParam<float> &param) {
+  // TODO(hjchen2)
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // DENSITY_PRIORBOX_OP
diff --git a/src/operators/kernel/arm/sequence_pool_kernel.cpp b/src/operators/kernel/arm/sequence_pool_kernel.cpp
index 8326c55515f2c8ab480c9efe96b7d105ed03d495..352158b973050c99555a82c0d0f02c318b7702ac 100644
--- a/src/operators/kernel/arm/sequence_pool_kernel.cpp
+++ b/src/operators/kernel/arm/sequence_pool_kernel.cpp
@@ -21,7 +21,7 @@ limitations under the License. */
 #include "common/types.h"
 #include "operators/kernel/sequence_kernels.h"
 #include "operators/math/pooling.h"
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
 #include <arm_neon.h>
 #endif  // __ARM_NEON__
@@ -44,7 +44,7 @@ void SequencePoolImpl(const framework::LoDTensor &input,
   if (width == 1) {
     float max = -std::numeric_limits<float>::max();
     int remain_h = height;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
     int loop = remain_h >> 2;
     remain_h = remain_h & 0x3;
     float32x4_t __max4 = math::vPoolInitq_f32<MAX>();
@@ -67,11 +67,11 @@ void SequencePoolImpl(const framework::LoDTensor &input,
     in_ptr += width;
     int remain_h = height - 1;
     int remain_w_start = 0;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
     remain_w_start = width & 0xfffc;
 #endif  // __ARM_NEON__
     for (int h = 0; h < remain_h; ++h) {
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
       for (int w = 0; w < width; w += 4) {
         float32x4_t __in = vld1q_f32(in_ptr + w);
         float32x4_t __out = vld1q_f32(out_ptr + w);
@@ -104,7 +104,7 @@ void SequencePoolImpl(const framework::LoDTensor &input,
   if (width == 1) {
     float sum = 0.f;
     int remain_h = height;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
     int loop = remain_h >> 2;
     remain_h = remain_h & 0x3;
     float32x4_t __sum4 = vdupq_n_f32(0.f);
@@ -126,12 +126,12 @@ void SequencePoolImpl(const framework::LoDTensor &input,
     in_ptr += width;
     int remain_h = height - 1;
     int remain_w_start = 0;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
     int loop_w = width >> 2;
     remain_w_start = width & 0xfffc;
 #endif  // __ARM_NEON__
     for (int h = 0; h < remain_h; ++h) {
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#ifdef __ARM_NEON__
       for (int w = 0; w < width - 3; w += 4) {
         float32x4_t __in = vld1q_f32(in_ptr + w);
         float32x4_t __out = vld1q_f32(out_ptr + w);
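
Since the new density-prior-box kernel body is still a TODO and the NEON paths above are dense, here is a scalar reference (illustrative only; the names and signature are hypothetical, not the operator's API) of what `SequencePoolImpl` computes for MAX and SUM pooling — each sequence of rows is reduced to a single row of `width` features:

```cpp
// Scalar reference for the NEON sequence pooling above (illustrative).
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

enum PoolType { MAX, SUM };

// `lod` holds sequence offsets: rows [lod[i], lod[i+1]) form sequence i.
void SequencePoolRef(const float *in, const std::vector<size_t> &lod,
                     size_t width, PoolType type, float *out) {
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    const float *seq = in + lod[i] * width;
    const size_t height = lod[i + 1] - lod[i];
    float *dst = out + i * width;
    for (size_t w = 0; w < width; ++w) {
      float acc = (type == MAX) ? -std::numeric_limits<float>::max() : 0.f;
      for (size_t h = 0; h < height; ++h) {
        const float v = seq[h * width + w];
        acc = (type == MAX) ? std::max(acc, v) : acc + v;
      }
      dst[w] = acc;
    }
  }
}
```
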
diff --git a/src/operators/kernel/prior_box_kernel.h b/src/operators/kernel/prior_box_kernel.h
index 921d5901a8f24abab61f7aa94663385d91e597a7..2a0f4e8f0155e74b8f6b3d75022c713df26e91c7 100644
--- a/src/operators/kernel/prior_box_kernel.h
+++ b/src/operators/kernel/prior_box_kernel.h
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#ifdef PRIORBOX_OP
-
 #pragma once

 #include
@@ -26,9 +24,10 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {

-inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
+#ifdef PRIORBOX_OP
+inline void ExpandAspectRatios(const std::vector<float> &input_aspect_ratior,
                                bool flip,
-                               std::vector<float>* output_aspect_ratior) {
+                               std::vector<float> *output_aspect_ratior) {
   constexpr float epsilon = 1e-6;
   output_aspect_ratior->clear();
   output_aspect_ratior->push_back(1.0f);
@@ -50,14 +49,63 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
   }
 }

-template <typename DeviceType, typename T>
-class PriorBoxKernel
-    : public framework::OpKernelBase<DeviceType, PriorBoxParam<DeviceType>> {
+DECLARE_KERNEL(PriorBox, PriorBoxParam);
+#endif  // PRIORBOX_OP
+
+#ifdef DENSITY_PRIORBOX_OP
+template <typename Dtype>
+class DensityPriorBoxParam : public OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+
+ public:
+  DensityPriorBoxParam(const VariableNameMap &inputs,
+                       const VariableNameMap &outputs,
+                       const AttributeMap &attrs, Scope *scope)
+      : OpParam(inputs, outputs, attrs, scope) {
+    input_ = InputFrom<GType>(inputs, *scope);
+    input_image_ = InputImageFrom<GType>(inputs, *scope);
+    output_boxes_ = OutputBoxesFrom<GType>(outputs, *scope);
+    output_variances_ = OutputVariancesFrom<GType>(outputs, *scope);
+    variances_ = GetAttr<vector<float>>("variances", attrs);
+    clip_ = GetAttr<bool>("clip", attrs);
+    flatten_to_2d_ = GetAttr<bool>("flatten_to_2d", attrs);
+    step_w_ = GetAttr<float>("step_w", attrs);
+    step_h_ = GetAttr<float>("step_h", attrs);
+    offset_ = GetAttr<float>("offset", attrs);
+    fixed_sizes_ = GetAttr<vector<float>>("fixed_sizes", attrs);
+    fixed_ratios_ = GetAttr<vector<float>>("fixed_ratios", attrs);
+    densities_ = GetAttr<vector<int>>("densities", attrs);
+  }
+
+  const GType *Input() const { return input_; }
+  const GType *InputImage() const { return input_image_; }
+  GType *OutputBoxes() const { return output_boxes_; }
+  GType *OutputVariances() const { return output_variances_; }
+  const bool Clip() const { return clip_; }
+  const bool FlattenTo2d() const { return flatten_to_2d_; }
+  const float StepW() const { return step_w_; }
+  const float StepH() const { return step_h_; }
+  const float Offset() const { return offset_; }
+  const vector<float> &FixedSizes() const { return fixed_sizes_; }
+  const vector<float> &FixedRatios() const { return fixed_ratios_; }
+  const vector<int> &Densities() const { return densities_; }

+ public:
-  void Compute(const PriorBoxParam<DeviceType>& param);
-  bool Init(PriorBoxParam<DeviceType>* param);
+  GType *input_;
+  GType *input_image_;
+  GType *output_boxes_;
+  GType *output_variances_;
+  vector<float> variances_;
+  bool clip_;
+  bool flatten_to_2d_;
+  float step_w_;
+  float step_h_;
+  float offset_;
+  vector<float> fixed_sizes_;
+  vector<float> fixed_ratios_;
+  vector<int> densities_;
 };
+
+DECLARE_KERNEL(DensityPriorBox, DensityPriorBoxParam);
+#endif  // DENSITY_PRIORBOX_OP
+
 }  // namespace operators
 }  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/math/gemm/cblas.cc b/src/operators/math/gemm/cblas.cc
index adc375b62913f0ad1105080f8c26b547e96671f3..058b61f1114c664e8e5ac3ae31f85e4186b9fd8a 100644
--- a/src/operators/math/gemm/cblas.cc
+++ b/src/operators/math/gemm/cblas.cc
@@ -17,7 +17,6 @@ limitations under the License. */
 #pragma once

 #include "operators/math/gemm/cblas.h"
-#include "operators/math/gemm/cpu_info.h"
 #include "operators/math/gemm/executor.h"
 #include "operators/math/gemm/strategy.h"

diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h
index ddbed3dbdf6a5399b0f945d7da98ed536ee5e4e2..ce43dc0257a01be5f6a55cb12d9b2b77f1a31086 100644
--- a/src/operators/math/gemm/executor.h
+++ b/src/operators/math/gemm/executor.h
@@ -19,17 +19,17 @@ limitations under the License. */
 #include <omp.h>
 #endif
 // #include
-// #include <iostream>
+#include <iostream>
 #include "common/log.h"
+#include "framework/context.h"
 #include "memory/t_malloc.h"
-#include "operators/math/gemm/cpu_info.h"
 #include "operators/math/gemm/gemm_kernel.h"

 namespace paddle_mobile {
 namespace operators {
 namespace math {

-static CPUInfo *info = CPUInfo::Info();
+static framework::CPUContext *g_cpu_ctx = framework::CPUContext::Context();

 int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
 unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
@@ -70,11 +70,11 @@ class GemmExecutor : public Executor {
     unsigned int L1_size = 0;
     unsigned int L2_size = 0;
     if (M_ > N_) {
-      L2_size = ResetL1Cache(info->L1_cache, num_threads_, M_, K_);
-      L1_size = info->L2_cache;
+      L2_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, M_, K_);
+      L1_size = g_cpu_ctx->L2_cache;
     } else {
-      L1_size = ResetL1Cache(info->L1_cache, num_threads_, N_, K_);
-      L2_size = info->L2_cache;
+      L1_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, N_, K_);
+      L2_size = g_cpu_ctx->L2_cache;
     }

     rhs_tile_num_ = L1_size / (K_ * sizeof(Itype));
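
The GEMM executor now pulls its cache sizes from the global `CPUContext`. A rough, illustrative model of how those numbers bound the packing tiles, mirroring `rhs_tile_num_ = L1_size / (K_ * sizeof(Itype))` above (the concrete `K` value here is invented):

```cpp
// Illustrative arithmetic only; mirrors GemmExecutor's tile sizing.
#include <cstdio>

int main() {
  const unsigned int L1_size = 32 * 1024;  // CPUContext default L1 (32 KiB)
  const int K = 256;                       // hypothetical shared GEMM dim
  // How many float rhs columns fit in L1 for one packed panel:
  const unsigned int rhs_tile_num = L1_size / (K * sizeof(float));
  std::printf("rhs_tile_num = %u\n", rhs_tile_num);  // 32 for these values
  return 0;
}
```
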
diff --git a/src/operators/prior_box_op.cpp b/src/operators/prior_box_op.cpp
index b2b43f6418e08e56f6b1af0023bc18fc342fb11d..abef57e2bb3809ce280ebb9f1cfce8ef981e5bd5 100644
--- a/src/operators/prior_box_op.cpp
+++ b/src/operators/prior_box_op.cpp
@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#ifdef PRIORBOX_OP
-
 #include "operators/prior_box_op.h"
 #include <vector>
+
 namespace paddle_mobile {
 namespace operators {

+#ifdef PRIORBOX_OP
 template <typename Dtype, typename T>
 void PriorBoxOp<Dtype, T>::InferShape() const {
   auto input_dims = this->param_.Input()->dims();
@@ -44,15 +44,55 @@ void PriorBoxOp<Dtype, T>::InferShape() const {
   this->param_.OutputBoxes()->Resize(framework::make_ddim(dim_vec));
   this->param_.OutputVariances()->Resize(framework::make_ddim(dim_vec));
 }
+#endif  // PRIORBOX_OP
+
+#ifdef DENSITY_PRIORBOX_OP
+template <typename Dtype, typename T>
+void DensityPriorBoxOp<Dtype, T>::InferShape() const {
+  auto input_dims = this->param_.Input()->dims();
+  auto input_image_dims = this->param_.InputImage()->dims();
+
+  auto &fixed_sizes = this->param_.FixedSizes();
+  auto &fixed_ratios = this->param_.FixedRatios();
+  auto &densities = this->param_.Densities();
+  bool flatten = this->param_.FlattenTo2d();
+
+  size_t num_priors = 0;
+  for (size_t i = 0; i < densities.size(); ++i) {
+    num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
+  }
+  if (!flatten) {
+    std::vector<int64_t> dim_vec(4);
+    dim_vec[0] = input_dims[2];
+    dim_vec[1] = input_dims[3];
+    dim_vec[2] = num_priors;
+    dim_vec[3] = 4;
+    this->param_.OutputBoxes()->Resize(framework::make_ddim(dim_vec));
+    this->param_.OutputVariances()->Resize(framework::make_ddim(dim_vec));
+  } else {
+    int64_t dim0 = input_dims[2] * input_dims[3] * num_priors;
+    this->param_.OutputBoxes()->Resize(framework::make_ddim({dim0, 4}));
+    this->param_.OutputVariances()->Resize(framework::make_ddim({dim0, 4}));
+  }
+}
+#endif  // DENSITY_PRIORBOX_OP

 }  // namespace operators
 }  // namespace paddle_mobile

 namespace ops = paddle_mobile::operators;
+
 #ifdef PADDLE_MOBILE_CPU
+#ifdef PRIORBOX_OP
 REGISTER_OPERATOR_CPU(prior_box, ops::PriorBoxOp);
-#endif
+#endif  // PRIORBOX_OP
+#ifdef DENSITY_PRIORBOX_OP
+REGISTER_OPERATOR_CPU(density_prior_box, ops::DensityPriorBoxOp);
+#endif  // DENSITY_PRIORBOX_OP
+#endif  // PADDLE_MOBILE_CPU
+
 #ifdef PADDLE_MOBILE_CL
+#ifdef PRIORBOX_OP
 REGISTER_OPERATOR_CL(prior_box, ops::PriorBoxOp);
-#endif
-#endif
+#endif  // PRIORBOX_OP
+#endif  // PADDLE_MOBILE_CL
diff --git a/src/operators/prior_box_op.h b/src/operators/prior_box_op.h
index 67d0cc6865fdc722a0191bc540a4d69c34ebedba..7a3c0466a01bb39131972fe699c66c5aa53f6a54 100644
--- a/src/operators/prior_box_op.h
+++ b/src/operators/prior_box_op.h
@@ -12,12 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#ifdef PRIORBOX_OP
-
 #pragma once

 #include <string>
-
 #include "framework/operator.h"
 #include "operators/kernel/prior_box_kernel.h"
 #include "operators/op_param.h"
@@ -25,26 +22,13 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {

-using paddle_mobile::framework::Tensor;
-
-template <typename DeviceType, typename T>
-class PriorBoxOp : public framework::OperatorWithKernel<
-                       DeviceType, PriorBoxParam<DeviceType>,
-                       operators::PriorBoxKernel<DeviceType, T>> {
- public:
-  PriorBoxOp(const std::string &type, const VariableNameMap &inputs,
-             const VariableNameMap &outputs,
-             const framework::AttributeMap &attrs, framework::Scope *scope)
-      : framework::OperatorWithKernel<DeviceType, PriorBoxParam<DeviceType>,
-                                      operators::PriorBoxKernel<DeviceType, T>>(
-            type, inputs, outputs, attrs, scope) {}
-
-  void InferShape() const override;
+#ifdef PRIORBOX_OP
+DECLARE_OPERATOR(PriorBox, PriorBoxParam, PriorBoxKernel);
+#endif

- protected:
-};
+#ifdef DENSITY_PRIORBOX_OP
+DECLARE_OPERATOR(DensityPriorBox, DensityPriorBoxParam, DensityPriorBoxKernel);
+#endif

 }  // namespace operators
 }  // namespace paddle_mobile
-
-#endif
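
To sanity-check the arithmetic in `DensityPriorBoxOp::InferShape`, a worked example with made-up attribute values (all numbers here are illustrative, not from any real model):

```cpp
// Illustrative check of the InferShape arithmetic; attributes invented.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> densities = {4, 2, 1};
  const std::vector<float> fixed_ratios = {1.0f};
  const int feat_h = 32, feat_w = 32;  // input_dims[2], input_dims[3]

  // num_priors = sum over i of fixed_ratios.size() * densities[i]^2
  size_t num_priors = 0;
  for (size_t i = 0; i < densities.size(); ++i) {
    num_priors +=
        fixed_ratios.size() * static_cast<size_t>(std::pow(densities[i], 2));
  }
  // num_priors = 16 + 4 + 1 = 21
  std::printf("flatten_to_2d boxes: {%zu, 4}\n",
              static_cast<size_t>(feat_h) * feat_w * num_priors);  // {21504, 4}
  std::printf("4-D boxes: {%d, %d, %zu, 4}\n", feat_h, feat_w, num_priors);
  return 0;
}
```
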