未验证 提交 9ad9635a 编写于 作者: xiebaiyuan's avatar xiebaiyuan 提交者: GitHub

Merge pull request #1504 from hjchen2/backup

Fix no effect if setting thread count in another thread
......@@ -14,42 +14,66 @@ limitations under the License. */
#pragma once
#if _OPENMP
#include <omp.h>
#endif
#define MOBILE_MAX_CPU_NUM 8
namespace paddle_mobile {
namespace operators {
namespace math {
namespace framework {
struct CPUInfo {
struct CPUContext {
private:
CPUInfo() {
CPUContext() : num_cpus(4), num_threads(1) {
// TODO(hjchen2)
num_cpus = 4;
for (int i = 0; i < num_cpus; ++i) {
cpu_frequency[i] = 2400; // 2400 MHz
max_cpu_frequency[i] = 2400; // 2400 MHz
cpu_frequencies[i] = 2400; // 2400 MHz
max_cpu_frequencies[i] = 2400; // 2400 MHz
}
// L1_cache = 32000; // 32K
L1_cache = 32 * 1024;
L2_cache = 2000000; // 2M
// L2_cache = 512000;
}
virtual ~CPUInfo() {}
public:
static CPUInfo* Info() {
static CPUInfo* ctx = new CPUInfo;
void set_num_threads(int threads) {
#if _ONENMP
omp_set_num_threads(threads);
if (threads <= omp_get_max_threads()) {
num_threads = threads;
} else {
num_threads = omp_get_max_threads();
}
#endif
num_threads = (num_threads > 1) ? num_threads : 1;
}
virtual ~CPUContext() {}
public:
static CPUContext* Context() {
static CPUContext* ctx = new CPUContext;
return ctx;
}
int num_cpus;
int cpu_frequency[MOBILE_MAX_CPU_NUM];
int max_cpu_frequency[MOBILE_MAX_CPU_NUM];
int num_threads;
int cpu_frequencies[MOBILE_MAX_CPU_NUM];
int max_cpu_frequencies[MOBILE_MAX_CPU_NUM];
int L1_cache;
int L2_cache;
};
} // namespace math
} // namespace operators
inline void set_global_num_threads(int threads) {
CPUContext::Context()->set_num_threads(threads);
}
inline int get_global_num_threads() {
return CPUContext::Context()->num_threads;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "common/enforce.h"
#include "common/log.h"
#include "framework/context.h"
#include "framework/framework.pb-c.h"
#include "framework/lod_tensor.h"
#include "framework/operator.h"
......@@ -37,6 +38,11 @@ namespace framework {
#pragma mark - executor
template <typename Device, typename T>
void Executor<Device, T>::SetThreadNum(int threads) {
set_global_num_threads(threads);
}
template <typename Device, typename T>
Executor<Device, T>::Executor(const Program<Device> &program,
paddle_mobile::PaddleMobileConfigInternal config,
......@@ -444,6 +450,9 @@ std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
template <typename Device, typename T>
PMStatus Executor<Device, T>::Predict() {
#if _OPENMP
omp_set_num_threads(get_global_num_threads());
#endif
#ifdef PADDLE_MOBILE_PROFILE
std::vector<ProfInfo> profile(ops_of_block0_.size());
struct timespec ts;
......@@ -654,14 +663,18 @@ void Executor<GPU_CL, float>::InitNoPersistableMemory(
output->Resize(input_tensor.dims());
output->mutable_data<float>();
}
template <>
void Executor<GPU_CL, float>::SetInput(const Tensor &input,
const std::string &var_name) {
auto *target_var = program_.scope->FindVar(var_name);
PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
var_name.c_str());
int index = 0;
if (feed_indices_.find(var_name) != feed_indices_.end()) {
index = feed_indices_.find(var_name)->second;
}
auto *feed_var = program_.scope->Var("feed");
framework::LoDTensor *target_tensor =
&(feed_var->template GetMutable<framework::LoDTensorArray>()->at(index));
auto *target_tensor = target_var->template GetMutable<LoDTensor>();
DLOG << "config_.load_when_predict " << config_.load_when_predict;
DLOG << "target_tensor->IsInitialized() " << target_tensor->IsInitialized();
DLOG << "target_tensor->dims() " << target_tensor->dims();
......@@ -772,7 +785,7 @@ void Executor<GPU_CL, float>::InitMemory() {
if (var_desc->Persistable()) {
CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<LoDTensor>();
var->template GetMutable<framework::LoDTensorArray>();
continue;
} else {
cl_image = var->template GetMutable<CLImage>();
......@@ -840,7 +853,7 @@ void Executor<GPU_CL, float>::InitCombineMemory() {
if (var_desc->Persistable()) {
CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<LoDTensor>();
var->template GetMutable<framework::LoDTensorArray>();
continue;
} else {
cl_image = var->template GetMutable<CLImage>();
......
......@@ -36,6 +36,8 @@ class Executor {
paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1,
const bool use_optimize = true, const bool lod_mode = false);
void SetThreadNum(int threads);
PMStatus Predict(const std::vector<std::pair<std::string, Tensor>> &inputs);
PMStatus Predict(
const std::vector<std::pair<std::string, LoDTensor>> &inputs);
......
......@@ -28,9 +28,7 @@ namespace paddle_mobile {
template <typename Device, typename T>
void PaddleMobile<Device, T>::SetThreadNum(int num) {
#ifdef _OPENMP
omp_set_num_threads(num);
#endif
executor_->SetThreadNum(num);
}
template <typename Device, typename T>
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DENSITY_PRIORBOX_OP
#include "operators/kernel/prior_box_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DensityPriorBoxKernel<CPU, float>::Init(DensityPriorBoxParam<CPU> *param) {
return true;
}
template <>
void DensityPriorBoxKernel<CPU, float>::Compute(
const DensityPriorBoxParam<CPU> &param) {
// TODO(hjchen2)
}
} // namespace operators
} // namespace paddle_mobile
#endif // DENSITY_PRIORBOX_OP
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "common/types.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/math/pooling.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif // __ARM_NEON__
......@@ -44,7 +44,7 @@ void SequencePoolImpl(const framework::LoDTensor &input,
if (width == 1) {
float max = -std::numeric_limits<float>::max();
int remain_h = height;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
remain_h = remain_h & 0x3;
float32x4_t __max4 = math::vPoolInitq_f32<MAX>();
......@@ -67,11 +67,11 @@ void SequencePoolImpl(const framework::LoDTensor &input,
in_ptr += width;
int remain_h = height - 1;
int remain_w_start = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
remain_w_start = width & 0xfffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
for (int w = 0; w < width; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
......@@ -104,7 +104,7 @@ void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
if (width == 1) {
float sum = 0.f;
int remain_h = height;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
remain_h = remain_h & 0x3;
float32x4_t __sum4 = vdupq_n_f32(0.f);
......@@ -126,12 +126,12 @@ void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
in_ptr += width;
int remain_h = height - 1;
int remain_w_start = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
int loop_w = width >> 2;
remain_w_start = width & 0xfffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#ifdef __ARM_NEON__
for (int w = 0; w < width - 3; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRIORBOX_OP
#pragma once
#include <algorithm>
......@@ -26,9 +24,10 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
#ifdef PRIORBOX_OP
inline void ExpandAspectRatios(const std::vector<float> &input_aspect_ratior,
bool flip,
std::vector<float>* output_aspect_ratior) {
std::vector<float> *output_aspect_ratior) {
constexpr float epsilon = 1e-6;
output_aspect_ratior->clear();
output_aspect_ratior->push_back(1.0f);
......@@ -50,14 +49,63 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
}
}
template <typename DeviceType, typename T>
class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam<DeviceType>> {
DECLARE_KERNEL(PriorBox, PriorBoxParam);
#endif // PRIORBOX_OP
#ifdef DENSITY_PRIORBOX_OP
template <typename Dtype>
class DensityPriorBoxParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
public:
DensityPriorBoxParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = InputFrom<GType>(inputs, *scope);
input_image_ = InputImageFrom<GType>(inputs, *scope);
output_boxes_ = OutputBoxesFrom<GType>(outputs, *scope);
output_variances_ = OutputVariancesFrom<GType>(outputs, *scope);
variances_ = GetAttr<vector<float>>("variances", attrs);
clip_ = GetAttr<bool>("clip", attrs);
flatten_to_2d_ = GetAttr<bool>("flatten_to_2d", attrs);
step_w_ = GetAttr<float>("step_w", attrs);
step_h_ = GetAttr<float>("step_h", attrs);
offset_ = GetAttr<float>("offset", attrs);
fixed_sizes_ = GetAttr<vector<float>>("fixed_sizes", attrs);
fixed_ratios_ = GetAttr<vector<float>>("fixed_ratios", attrs);
densities_ = GetAttr<vector<int>>("densities", attrs);
}
const GType *Input() const { return input_; }
const GType *InputImage() const { return input_image_; }
GType *OutputBoxes() const { return output_boxes_; }
GType *OutputVariances() const { return output_variances_; }
const bool Clip() const { return clip_; }
const bool FlattenTo2d() const { return flatten_to_2d_; }
const float StepW() const { return step_w_; }
const float StepH() const { return step_h_; }
const float Offset() const { return offset_; }
const vector<float> &FixedSizes() const { return fixed_sizes_; }
const vector<float> &FixedRatios() const { return fixed_ratios_; }
const vector<int> &Densities() const { return densities_; }
public:
void Compute(const PriorBoxParam<DeviceType>& param);
bool Init(PriorBoxParam<DeviceType>* param);
GType *input_;
GType *input_image_;
GType *output_boxes_ GType *output_variances_;
bool clip_;
bool flatten_to_2d_;
float step_w_;
float step_h_;
float offset_;
vector<float> fixed_sizes_;
vector<float> fixed_ratios_;
vector<int> densities_;
};
DECLARE_KERNEL(DensityPriorBox, DensityPriorBoxParam);
#endif // DENSITY_PRIORBOX_OP
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -17,7 +17,6 @@ limitations under the License. */
#pragma once
#include "operators/math/gemm/cblas.h"
#include "operators/math/gemm/cpu_info.h"
#include "operators/math/gemm/executor.h"
#include "operators/math/gemm/strategy.h"
......
......@@ -19,17 +19,17 @@ limitations under the License. */
#include <omp.h>
#endif
// #include <sys/time.h>
// #include <iostream>
#include <iostream>
#include "common/log.h"
#include "framework/context.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
#include "operators/math/gemm/gemm_kernel.h"
namespace paddle_mobile {
namespace operators {
namespace math {
static CPUInfo *info = CPUInfo::Info();
static framework::CPUContext *g_cpu_ctx = framework::CPUContext::Context();
int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
......@@ -70,11 +70,11 @@ class GemmExecutor : public Executor {
unsigned int L1_size = 0;
unsigned int L2_size = 0;
if (M_ > N_) {
L2_size = ResetL1Cache(info->L1_cache, num_threads_, M_, K_);
L1_size = info->L2_cache;
L2_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, M_, K_);
L1_size = g_cpu_ctx->L2_cache;
} else {
L1_size = ResetL1Cache(info->L1_cache, num_threads_, N_, K_);
L2_size = info->L2_cache;
L1_size = ResetL1Cache(g_cpu_ctx->L1_cache, num_threads_, N_, K_);
L2_size = g_cpu_ctx->L2_cache;
}
rhs_tile_num_ = L1_size / (K_ * sizeof(Itype));
......
......@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRIORBOX_OP
#include "operators/prior_box_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
#ifdef PRIORBOX_OP
template <typename Dtype, typename T>
void PriorBoxOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.Input()->dims();
......@@ -44,15 +44,55 @@ void PriorBoxOp<Dtype, T>::InferShape() const {
this->param_.OutputBoxes()->Resize(framework::make_ddim(dim_vec));
this->param_.OutputVariances()->Resize(framework::make_ddim(dim_vec));
}
#endif // PRIORBOX_OP
#ifdef DENSITY_PRIORBOX_OP
template <typename Dtype, typename T>
void DensityPriorBoxOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.Input()->dims();
auto input_image_dims = this->param_.InputImage()->dims();
auto &fixed_sizes = this->param_.FixedSizes();
auto &fixed_ratios = this->param_.FixedRatios();
auto &densities = this->param_.Densities();
bool flatten = this->param_.FlattenTo2d();
size_t num_priors = 0;
for (size_t i = 0; i < densities.size(); ++i) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
if (!flatten) {
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_priors;
dim_vec[3] = 4;
this->param_.OutputBoxes()->Resize(framework::make_ddim(dim_vec));
this->param_.OutputVariances()->Resize(framework::make_ddim(dim_vec));
} else {
int64_t dim0 = input_dims[2] * input_dims[3] * num_priors;
this->param_.OutputBoxes()->Resize(framework::make_ddim({dim0, 4}));
this->param_.OutputVariances()->Resize(framework::make_ddim({dim0, 4}));
}
}
#endif // DENSITY_PRIORBOX_OP
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
#ifdef PRIORBOX_OP
REGISTER_OPERATOR_CPU(prior_box, ops::PriorBoxOp);
#endif
#endif // PRIORBOX_OP
#ifdef DENSITY_PRIORBOX_OP
REGISTER_OPERATOR_CPU(density_prior_box, ops::DensityPriorBoxOp);
#endif // DENSITY_PRIORBOX_OP
#endif // PADDLE_MOBILE_CPU
#ifdef PADDLE_MOBILE_CL
#ifdef PRIORBOX_OP
REGISTER_OPERATOR_CL(prior_box, ops::PriorBoxOp);
#endif
#endif
#endif // PRIORBOX_OP
#endif // PADDLE_MOBILE_CL
......@@ -12,12 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRIORBOX_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/prior_box_kernel.h"
#include "operators/op_param.h"
......@@ -25,26 +22,13 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class PriorBoxOp : public framework::OperatorWithKernel<
DeviceType, PriorBoxParam<DeviceType>,
operators::PriorBoxKernel<DeviceType, T>> {
public:
PriorBoxOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs, framework::Scope *scope)
: framework::OperatorWithKernel<DeviceType, PriorBoxParam<DeviceType>,
operators::PriorBoxKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
#ifdef PRIORBOX_OP
DECLARE_OPERATOR(PriorBox, PriorBoxParam, PriorBoxKernel);
#endif
protected:
};
#ifdef DENSITY_PRIORBOX_OP
DECLARE_OPERATOR(DensityPriorBox, DensityPriorBoxParam, DensityPriorBoxKernel);
#endif
} // namespace operators
} // namespace paddle_mobile
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册