提交 ae7c7d90 编写于 作者: xiebaiyuan's avatar xiebaiyuan

add memory load high api

上级 ca99ebd9
......@@ -29,7 +29,14 @@ PaddleMobilePredictor<Dtype, P>::PaddleMobilePredictor(
template <typename Dtype, Precision P>
bool PaddleMobilePredictor<Dtype, P>::Init(const PaddleMobileConfig &config) {
paddle_mobile_.reset(new PaddleMobile<Dtype, P>());
if (!config.model_dir.empty()) {
if (config.memory_pack.from_memory) {
DLOG << "load from memory!";
paddle_mobile_->LoadCombinedMemory(config.memory_pack.model_size,
config.memory_pack.model_buf,
config.memory_pack.combined_params_size,
config.memory_pack.combined_params_buf);
} else if (!config.model_dir.empty()) {
paddle_mobile_->Load(config.model_dir, config.optimize,
config.quantification, config.batch_size);
} else if (!config.prog_file.empty() && !config.param_file.empty()) {
......
......@@ -111,6 +111,14 @@ class PaddlePredictor {
PaddlePredictor() = default;
};
struct PaddleModelMemoryPack {
bool from_memory = false;
size_t model_size = 0;
uint8_t* model_buf = nullptr;
size_t combined_params_size = 0;
uint8_t* combined_params_buf = nullptr;
};
struct PaddleMobileConfig : public PaddlePredictor::Config {
enum Precision { FP32 = 0 };
enum Device { kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };
......@@ -124,6 +132,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
int thread_num = 1;
std::string prog_file;
std::string param_file;
struct PaddleModelMemoryPack memory_pack;
};
// A factory to help create different predictors.
......
......@@ -236,6 +236,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp)
target_link_libraries(test-loadmemory paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-loadmemory-inference framework/test_load_memory_inference_api.cpp)
target_link_libraries(test-loadmemory-inference paddle-mobile)
ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp)
target_link_libraries(test-inference-api paddle-mobile)
......
......@@ -58,9 +58,9 @@ int main() {
size_t sizeBuf = ReadBuffer(model_path.c_str(), &bufModel);
uint8_t *bufParams = nullptr;
DLOG << "sizeBuf: " << sizeBuf;
std::cout << "sizeBuf: " << sizeBuf << std::endl;
size_t sizeParams = ReadBuffer(params_path.c_str(), &bufParams);
DLOG << "sizeParams: " << sizeParams;
std::cout << "sizeParams: " << sizeParams << std::endl;
paddle_mobile.LoadCombinedMemory(sizeBuf, bufModel, sizeParams, bufParams);
return 0;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <iostream>
#include "../test_helper.h"
#include "io/paddle_inference_api.h"
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
FILE *fp;
fp = fopen(file_name, "rb");
PADDLE_MOBILE_ENFORCE(fp != nullptr, " %s open failed !", file_name);
fseek(fp, 0, SEEK_END);
auto size = static_cast<size_t>(ftell(fp));
rewind(fp);
DLOG << "model size: " << size;
*out = reinterpret_cast<uint8_t *>(malloc(size));
size_t cur_len = 0;
size_t nread;
while ((nread = fread(*out + cur_len, 1, size - cur_len, fp)) != 0) {
cur_len += nread;
}
fclose(fp);
return cur_len;
}
static char *Get_binary_data(std::string filename) {
FILE *file = fopen(filename.c_str(), "rb");
PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
filename.c_str());
fseek(file, 0, SEEK_END);
int64_t size = ftell(file);
PADDLE_MOBILE_ENFORCE(size > 0, "size is too small");
rewind(file);
auto *data = new char[size];
size_t bytes_read = fread(data, 1, size, file);
PADDLE_MOBILE_ENFORCE(bytes_read == size,
"read binary file bytes do not match with fseek");
fclose(file);
return data;
}
paddle_mobile::PaddleMobileConfig GetConfig() {
paddle_mobile::PaddleMobileConfig config;
config.precision = paddle_mobile::PaddleMobileConfig::FP32;
config.device = paddle_mobile::PaddleMobileConfig::kCPU;
const std::shared_ptr<paddle_mobile::PaddleModelMemoryPack> &memory_pack =
std::make_shared<paddle_mobile::PaddleModelMemoryPack>();
auto model_path = std::string(g_genet_combine) + "/model";
auto params_path = std::string(g_genet_combine) + "/params";
memory_pack->model_size =
ReadBuffer(model_path.c_str(), &memory_pack->model_buf);
std::cout << "sizeBuf: " << memory_pack->model_size << std::endl;
memory_pack->combined_params_size =
ReadBuffer(params_path.c_str(), &memory_pack->combined_params_buf);
std::cout << "sizeParams: " << memory_pack->combined_params_size << std::endl;
memory_pack->from_memory = true;
config.memory_pack = *memory_pack;
config.thread_num = 4;
return config;
}
int main() {
paddle_mobile::PaddleMobileConfig config = GetConfig();
auto predictor = paddle_mobile::CreatePaddlePredictor<
paddle_mobile::PaddleMobileConfig,
paddle_mobile::PaddleEngineKind::kPaddleMobile>(config);
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册