提交 8b6cc00a 编写于 作者: L liuruilong

add code comment

上级 4bde43dd
......@@ -68,93 +68,6 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) {
return cur_len;
}
template <typename Dtype, Precision P>
void Loader<Dtype, P>::LoadVar(framework::Variable *variable,
const framework::VarDesc &var_desc,
const std::string &file_path) {
auto tensor = variable->GetMutable<framework::LoDTensor>();
char *data = Get_binary_data(file_path);
// 1. version
uint32_t version = *(uint32_t *)data;
data += sizeof(uint32_t);
// 2 Lod information
uint32_t lod_level = *(uint64_t *)data;
data += sizeof(uint64_t);
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint32_t size = *(uint64_t *)data;
data += sizeof(uint64_t);
std::vector<size_t> tmp(size / sizeof(size_t));
for (int k = 0; k < tmp.size(); ++k) {
tmp[k] = *(size_t *)data;
}
lod[i] = tmp;
}
// 3. tensor version
uint32_t tensor_version = *(uint32_t *)data;
data += sizeof(uint32_t);
// 4. tensor desc
uint32_t size = *(int32_t *)data;
data += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
buf.get()[m] = data[m];
}
const framework::TensorDesc &desc = var_desc.Tensor_desc();
PaddleMobile__Framework__Proto__VarType__TensorDesc *tensor_desc = NULL;
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
}
tensor->Resize(framework::make_ddim(desc.Dims()));
void *memory = tensor;
int type_size = 0;
switch (desc.DataType()) {
case framework::VARTYPE_TYPE_FP16:
type_size = 2;
break;
case framework::VARTYPE_TYPE_FP32:
type_size = 4;
memory = tensor->mutable_data<float>();
break;
case framework::VARTYPE_TYPE_FP64:
type_size = 8;
break;
case framework::VARTYPE_TYPE_INT32:
type_size = 4;
break;
case framework::VARTYPE_TYPE_INT64:
type_size = 8;
break;
case framework::VARTYPE_TYPE_BOOL:
type_size = 1;
break;
default:
break;
}
for (int n = 0; n < memory_size * type_size; ++n) {
static_cast<char *>(memory)[n] = data[n];
}
delete data;
}
template <typename Dtype, Precision P>
const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &dirname, bool optimize) {
......
......@@ -20,29 +20,34 @@ limitations under the License. */
#include <vector>
#include "common/types.h"
#include "framework/lod_tensor.h"
#include "framework/tensor.h"
#include "framework/operator.h"
#include "framework/lod_tensor.h"
#include "framework/program/program.h"
#include "framework/tensor.h"
namespace paddle_mobile {
template <typename Dtype = CPU, Precision P = Precision::FP32>
class Loader {
public:
/*
* @b load separate format fluid model
* @b 加载分开形式的 fluid 模型
* */
const framework::Program<Dtype, P> Load(const std::string &dirname,
bool optimize = false);
/*
* @b load combine format fluid mode
* @b 加载结合在一起格式的模型
* */
const framework::Program<Dtype, P> Load(const std::string &model_path,
const std::string &para_path,
bool optimize = false);
private:
const framework::Program<Dtype, P> LoadProgram(const std::string &model_path,
bool optimize = false);
void LoadVar(framework::Variable *variable,
const framework::VarDesc &var_desc,
const std::string &file_path);
};
template <typename Dtype = CPU, Precision P = Precision::FP32>
......@@ -50,17 +55,28 @@ class Executor {
public:
typedef typename PrecisionTrait<P>::ptype Ptype;
/*
* @b init executor with program load by Loader class
* @b 用 loader load 的 program 实例化 executor
* */
Executor(const framework::Program<Dtype> p, int batch_size = 1,
bool use_optimize = true);
/*
* @b to predict
* */
std::shared_ptr<framework::Tensor> Predict(const framework::Tensor &t);
/*
* @b to predict with vector and dim
*
* @b 使用 输入 和 输入的维度信息 进行预测
* */
std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims);
protected:
Executor() = default;
void InitMemory();
void LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor, char *&data);
......
......@@ -70,7 +70,6 @@ build_for_android() {
-DCMAKE_TOOLCHAIN_FILE="${TOOLCHAIN_FILE}" \
-DANDROID_PLATFORM="${ANDROID_PLATFORM_VERSION}" \
-DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
-DCMAKE_LDFLAGS="-Wl,--gc-sections --icf=safe" \
-DANDROID_STL=c++_static \
-DANDROID=true \
-D"${NET}=true" \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册