Commit 47ff9c15 authored by liuruilong

add load when predict

Parent 69ae3e0b
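Summary: this commit threads a PaddleMobileConfigInternal from PaddleMobile down into framework::Executor and adds a load_when_predict flag. When the flag is set, SetInput re-initializes the memory of non-persistable (intermediate) tensors whenever the incoming feed tensor's dims differ from the previous run, so one loaded model can serve inputs of varying size. A minimal usage sketch, mirroring the test change further down (model paths omitted):

    // Opt in to re-sizing intermediate memory at predict time.
    paddle_mobile::PaddleMobileConfigInternal config;
    config.load_when_predict = true;
    paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile(config);
    // paddle_mobile.Load(...) and Predict(...) as usual.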
......@@ -107,6 +107,10 @@ enum PoolingType {
  AVG = 1,
};

struct PaddleMobileConfigInternal {
  bool load_when_predict = false;
};
extern const char *G_OP_TYPE_CONV;
extern const char *G_OP_TYPE_BATCHNORM;
extern const char *G_OP_TYPE_BOX_CODER;
......
......@@ -37,6 +37,12 @@ namespace framework {
#pragma mark - executor
template <typename Device, typename T>
Executor<Device, T>::Executor(const Program<Device> &program,
                              paddle_mobile::PaddleMobileConfigInternal config,
                              int batch_size, const bool use_optimize,
                              const bool lod_mode)
    : Executor(program, batch_size, use_optimize, lod_mode) {
  config_ = config;
}

template <typename Device, typename T>
Executor<Device, T>::Executor(const Program<Device> &program, int batch_size,
                              const bool use_optimize, const bool lod_mode)
......@@ -212,10 +218,16 @@ void Executor<Device, T>::InitCombineMemory() {
        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
          continue;
        }
        DLOG << " init combine memory persistable: " << var_desc->Name();
        LoadMemory(reinterpret_cast<void **>(&data), var_desc, tensor);
      } else {
        if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
          DLOG << " init combine memory no persistable in lod: "
               << var_desc->Name();
          varInputMemory(var_desc, var, tensor);
        } else {
          DLOG << " init combine memory no persistable: " << var_desc->Name();
        }
      }
    }
......@@ -226,6 +238,32 @@ void Executor<Device, T>::InitCombineMemory() {
  LOG(kLOG_INFO) << "init combine memory finish";
}

template <typename Device, typename T>
void Executor<Device, T>::InitNoPersistableMemory(
    const LoDTensor &input_tensor) {
  for (const auto &block : program_desc_->Blocks()) {
    for (const auto &var_desc : block->Vars()) {
      auto var = program_.scope->Var(var_desc->Name());
      auto tensor = var->template GetMutable<LoDTensor>();
      if (var_desc->Persistable()) {
        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
          continue;
        }
      } else {
        if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
          DDim tensor_dim = tensor->dims();
          DDim new_dim =
              make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2],
                         input_tensor.dims()[3]});
          tensor->template Resize(new_dim);
          tensor->template mutable_data<T>();
        }
      }
    }
  }

  std::shared_ptr<LoDTensor> output = GetOutput("fetch");
  output->Resize(input_tensor.dims());
  output->mutable_data<T>();
}
template <typename Device, typename T>
bool Executor<Device, T>::varInputMemory(
    const std::shared_ptr<VarDesc> &var_desc, Variable *var,
......@@ -275,6 +313,7 @@ PMStatus Executor<Device, T>::Predict(
template <typename Device, typename T>
std::vector<T> Executor<Device, T>::Predict(const std::vector<T> &input,
                                            const std::vector<int64_t> &dims) {
  Tensor feed_tensor(input, make_ddim(dims));
  SetInput(feed_tensor, "feed");
  std::vector<T> output;
......@@ -293,7 +332,15 @@ void Executor<Device, T>::SetInput(const Tensor &input,
  auto *target_var = program_.scope->FindVar(var_name);
  PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
                        var_name.c_str());
  auto *target_tensor = target_var->template GetMutable<LoDTensor>();

  if (config_.load_when_predict) {
    if (target_tensor->IsInitialized() &&
        target_tensor->dims() != input.dims()) {
      InitNoPersistableMemory(*target_tensor);
    }
  }

  target_tensor->Resize(input.dims());
  target_tensor->ShareDataWith(input);
}
......@@ -301,10 +348,18 @@ void Executor<Device, T>::SetInput(const Tensor &input,
template <typename Device, typename T>
void Executor<Device, T>::SetInput(const LoDTensor &input,
                                   const std::string &var_name) {
  auto *target_var = program_.scope->FindVar(var_name);
  PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
                        var_name.c_str());
  auto *target_tensor = target_var->template GetMutable<LoDTensor>();

  if (config_.load_when_predict) {
    if (target_tensor->IsInitialized() &&
        target_tensor->dims() != input.dims()) {
      InitNoPersistableMemory(*target_tensor);
    }
  }

  target_tensor->Resize(input.dims());
  target_tensor->ShareDataWith(input);
  target_tensor->set_lod(input.lod());
......
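The core of the change in executor.cpp is InitNoPersistableMemory: every non-persistable LoD tensor keeps its batch and channel dims but takes the new input's height and width before its buffer is re-allocated, and the "fetch" output is resized to the input dims. A small sketch of that resize rule (the helper name is illustrative, not part of the patch):

    // Hypothetical helper showing the rule applied per intermediate tensor:
    // keep N and C from the tensor, take H and W from the new input.
    DDim ResizedForNewInput(const DDim &tensor_dim, const DDim &input_dim) {
      return make_ddim({tensor_dim[0], tensor_dim[1], input_dim[2], input_dim[3]});
    }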
......@@ -32,6 +32,8 @@ namespace framework {
template <typename Device, typename T = float>
class Executor {
 public:
  Executor(const Program<Device> &program,
           paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1,
           const bool use_optimize = true, const bool lod_mode = false);

  Executor(const Program<Device> &program, int batch_size = 1,
           const bool use_optimize = true, const bool lod_mode = false);
......@@ -60,10 +62,13 @@ class Executor {
 protected:
  Executor() = default;

  bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc, Variable *var,
                      LoDTensor *tensor) const;

  void InitMemory();
  void InitCombineMemory();
  void InitNoPersistableMemory(const LoDTensor &input_tensor);
  void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc,
                  LoDTensor *tensor);
#ifdef PADDLE_MOBILE_CL
......@@ -73,14 +78,18 @@ class Executor {
  int batch_size_;
  bool use_optimize_;
  bool lod_mode_;
  PaddleMobileConfigInternal config_ = PaddleMobileConfigInternal();

  Program<Device> program_;
  std::shared_ptr<ProgramDesc> program_desc_;

  typedef std::shared_ptr<OperatorBase<Device>> OperatorBasePtr;
  std::vector<std::vector<OperatorBasePtr>> ops_of_block_;
  // operators list
  std::vector<OperatorBasePtr> ops_list_;

  // for super resolution
  DDim input_dim_;

#ifdef PADDLE_MOBILE_PROFILE
  struct ProfInfo {
    int tid = 0;
......
......@@ -25,6 +25,7 @@ namespace framework {
template <typename Device = CPU, typename T = float>
class Loader {
 public:
  /*
   * @b load separate format fluid model
   * @b load a fluid model whose program and parameters are stored as separate files
......@@ -59,6 +60,7 @@ class Loader {
  void InitMemoryFromProgram(
      const std::shared_ptr<ProgramDesc> &originProgramDesc,
      const std::shared_ptr<Scope> &scope);
};
} // namespace framework
......
......@@ -42,7 +42,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
  if (executor_.get() == nullptr) {
    executor_ = std::make_shared<framework::Executor<Device, T>>(
-       loader_->Load(dirname, optimize, quantification), batch_size, optimize,
+       loader_->Load(dirname, optimize, quantification), config_, batch_size, optimize,
        loddable);
  } else {
    LOG(kLOG_INFO) << "executor inited";
......@@ -64,8 +64,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
  if (executor_.get() == nullptr) {
    executor_ = std::make_shared<framework::Executor<Device, T>>(
-       loader_->Load(model_path, para_path, optimize, quantification),
-       batch_size, optimize, loddable);
+       loader_->Load(model_path, para_path, optimize, quantification), config_, batch_size, optimize, loddable);
  } else {
    LOG(kLOG_INFO) << "executor inited";
  }
......@@ -87,7 +86,7 @@ bool PaddleMobile<Device, T>::LoadCombinedMemory(
    executor_ = std::make_shared<framework::Executor<Device, T>>(
        loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len,
                                    combined_params_buf, optimize,
-                                   quantification),
+                                   quantification), config_,
        batch_size, optimize, loddable);
  } else {
    LOG(kLOG_INFO) << "executor inited";
......
......@@ -33,9 +33,18 @@ limitations under the License. */
namespace paddle_mobile {

template <typename Device, typename T = float>
class PaddleMobile {
 public:
  PaddleMobile(PaddleMobileConfigInternal config) : config_(config) {
#ifndef PADDLE_MOBILE_CL
    bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Device>::value;
    PADDLE_MOBILE_ENFORCE(!is_gpu, "Please recompile with GPU_CL is on");
#endif
  }

  PaddleMobile() {
#ifndef PADDLE_MOBILE_CL
    bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Device>::value;
......@@ -100,6 +109,7 @@ class PaddleMobile {
 private:
  std::shared_ptr<framework::Loader<Device, T>> loader_;
  std::shared_ptr<framework::Executor<Device, T>> executor_;
  PaddleMobileConfigInternal config_;
};
} // namespace paddle_mobile
......@@ -18,7 +18,10 @@ limitations under the License. */
#include "../test_include.h"

int main() {
-  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
+  paddle_mobile::PaddleMobileConfigInternal config;
+  config.load_when_predict = true;
+  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile(config);
  //  paddle_mobile.SetThreadNum(4);
  auto time1 = paddle_mobile::time();
#ifdef PADDLE_MOBILE_CL
......@@ -27,7 +30,7 @@ int main() {
  auto isok = paddle_mobile.Load(std::string(g_super) + "/model",
                                 std::string(g_super) + "/params", true, false,
-                                1, true);
+                                1, false);
  //  auto isok = paddle_mobile.Load(std::string(g_mobilenet_mul), true);
  if (isok) {
......
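With load_when_predict enabled as in this test, a single instance can be fed inputs of different spatial sizes across calls: SetInput detects that the feed tensor's dims changed and calls InitNoPersistableMemory before sharing the new data. A hedged sketch of that pattern (the shapes are illustrative, and the Predict(input, dims) overload is assumed to be the existing vector-based entry point):

    // Two predictions with different input sizes on the same instance.
    std::vector<float> small(1 * 3 * 224 * 224, 0.f);
    std::vector<float> large(1 * 3 * 448 * 448, 0.f);
    auto out_small = paddle_mobile.Predict(small, {1, 3, 224, 224});
    auto out_large = paddle_mobile.Predict(large, {1, 3, 448, 448});  // dims changed -> intermediates re-sized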