Commit 3a28a80b authored by Zhen Wang, committed by Yan Chunwei

add precision and persistable attrs for the tensor. (#1899) (#1917)

* Add precision and persistable attrs for the tensor, and fix the cxx light and full API demos.

* update precision2string methods. test=develop

* Move the save logic before the run in mobilenetv1_full_api.cc. test=develop

* add comments for UpdateVarsOfProgram. test=develop
Parent 443b2e68
@@ -31,6 +31,7 @@ void Predictor::SaveModel(const std::string &dir,
     GenRuntimeProgram();
   }
   program_->SaveOpInfosToProgram(&program_desc_);
+  program_->UpdateVarsOfProgram(&program_desc_);
   switch (model_type) {
     case lite_api::LiteModelType::kProtobuf:
       SaveModelPb(dir, *program_->exec_scope(), program_desc_, true);
......
@@ -54,8 +54,15 @@ const std::string& TargetToStr(TargetType target) {
 }
 const std::string& PrecisionToStr(PrecisionType precision) {
-  static const std::string precision2string[] = {
-      "unk", "float", "int8_t", "int32_t", "any", "float16", "bool"};
+  static const std::string precision2string[] = {"unk",
+                                                 "float",
+                                                 "int8_t",
+                                                 "int32_t",
+                                                 "any",
+                                                 "float16",
+                                                 "bool",
+                                                 "int64_t",
+                                                 "int16_t"};
   auto x = static_cast<int>(precision);
   CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
   return precision2string[x];
@@ -84,8 +91,15 @@ const std::string& TargetRepr(TargetType target) {
 }
 const std::string& PrecisionRepr(PrecisionType precision) {
-  static const std::string precision2string[] = {
-      "kUnk", "kFloat", "kInt8", "kInt32", "kAny", "kFP16", "kBool"};
+  static const std::string precision2string[] = {"kUnk",
+                                                 "kFloat",
+                                                 "kInt8",
+                                                 "kInt32",
+                                                 "kAny",
+                                                 "kFP16",
+                                                 "kBool",
+                                                 "kInt64",
+                                                 "kInt16"};
   auto x = static_cast<int>(precision);
   CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
   return precision2string[x];
......
@@ -57,11 +57,13 @@ enum class PrecisionType : int {
   kUnk = 0,
   kFloat = 1,
   kInt8 = 2,
-  kFP16 = 5,
   kInt32 = 3,
   kAny = 4,  // any precision
+  kFP16 = 5,
   kBool = 6,
-  NUM = 7,  // number of fields.
+  kInt64 = 7,
+  kInt16 = 8,
+  NUM = 9,  // number of fields.
 };
 enum class DataLayoutType : int {
   kUnk = 0,
......
@@ -958,7 +958,6 @@ void DeviceInfo::RequestPowerRandLowMode(int shift_num, int thread_num) {
 int DeviceInfo::Setup() {
   core_num_ = get_cpu_num();
-  LOG(INFO) << " CPU core number: " << core_num_;
   mem_size_ = get_mem_size();
   get_cpu_arch(&archs_, core_num_);
   // set defalut CPU info
......
@@ -175,6 +175,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
     for (size_t i = 0; i < weight_num; i++) {
       quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
     }
+    quantized_weight_t->set_persistable(true);
+    quantized_weight_t->set_precision(PRECISION(kInt8));
     auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
     quantized_op->Attach(op_desc, scope);
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "lite/core/program.h"
+#include <unordered_map>
 #include "lite/model_parser/cpp/block_desc.h"
 #include "lite/model_parser/cpp/op_desc.h"
 #include "lite/model_parser/cpp/var_desc.h"
@@ -38,6 +39,78 @@ void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
   }
 }
+// `UpdateVarsOfProgram` will remove unused var_descs and add the descs of
+// newly created vars in block 0. Currently, the type of a newly created var
+// can only be LOD_TENSOR.
+void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
+  CHECK(desc);
+  CHECK(desc->BlocksSize());
+  std::unordered_map<std::string, cpp::VarDesc> origin_var_maps;
+  auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
+  auto var_size = main_block.VarsSize();
+  for (int i = 0; i < var_size; i++) {
+    auto v = main_block.GetVar<cpp::VarDesc>(i);
+    auto name = v->Name();
+    origin_var_maps.emplace(name, *v);
+  }
+  main_block.ClearVars();
+  for (auto& node : instructions_) {
+    auto* op = const_cast<lite::OpLite*>(node.op());
+    auto* kernel = node.kernel();
+    auto* scope = op->scope();
+    auto in_names = op->op_info()->input_names();
+    auto out_names = op->op_info()->output_names();
+    for (auto& in_name : in_names) {
+      auto it = origin_var_maps.find(in_name);
+      if (it != origin_var_maps.end()) {
+        auto* v = main_block.AddVar<cpp::VarDesc>();
+        v->SetName((it->second).Name());
+        v->SetType((it->second).GetType());
+        v->SetPersistable((it->second).Persistable());
+      } else {
+        // Newly created vars must be LOD_TENSOR
+        auto* v = main_block.AddVar<cpp::VarDesc>();
+        v->SetName(in_name);
+        v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
+        std::string in_arg_name;
+        op->op_info()->GetInputArgname(in_name, &in_arg_name);
+        auto type = kernel->GetInputDeclType(in_arg_name);
+        if (type->IsTensor()) {
+          auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
+          v->SetPersistable(tensor->persistable());
+        } else {
+          CHECK(false) << "unsupported var type";
+        }
+      }
+    }
+    for (auto& out_name : out_names) {
+      auto it = origin_var_maps.find(out_name);
+      if (it != origin_var_maps.end()) {
+        auto* v = main_block.AddVar<cpp::VarDesc>();
+        v->SetName((it->second).Name());
+        v->SetType((it->second).GetType());
+        v->SetPersistable((it->second).Persistable());
+      } else {
+        // Newly created vars must be LOD_TENSOR
+        auto* v = main_block.AddVar<cpp::VarDesc>();
+        v->SetName(out_name);
+        v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
+        std::string out_arg_name;
+        op->op_info()->GetOutputArgname(out_name, &out_arg_name);
+        auto type = kernel->GetOutputDeclType(out_arg_name);
+        if (type->IsTensor()) {
+          auto tensor = scope->FindVar(out_name)->GetMutable<Tensor>();
+          v->SetPersistable(tensor->persistable());
+        } else {
+          CHECK(false) << "unsupported var type";
+        }
+      }
+    }
+  }
+}
 void RuntimeProgram::Run() {
   for (auto& inst : instructions_) {
     VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
......
@@ -137,8 +137,15 @@ class LITE_API RuntimeProgram {
   const std::vector<Instruction>& instructions() const { return instructions_; }
+  // `SaveOpInfosToProgram` will update the op list (ops_) of block 0
+  // in the ProgramDesc.
   void SaveOpInfosToProgram(cpp::ProgramDesc* desc);
+  // `UpdateVarsOfProgram` will update the var list (vars_) of block 0 in
+  // the ProgramDesc. Namely, if a new var is created by some pass, its
+  // var_desc will be added to vars_.
+  void UpdateVarsOfProgram(cpp::ProgramDesc* desc);
 private:
   RuntimeProgram(const RuntimeProgram&) = delete;
   std::vector<Instruction> instructions_;
......
@@ -126,6 +126,12 @@ class TensorLite {
   LoD *mutable_lod() { return &lod_; }
   void set_lod(const LoD &lod) { lod_ = lod; }
+  PrecisionType precision() const { return precision_; }
+  void set_precision(PrecisionType precision) { precision_ = precision; }
+  bool persistable() const { return persistable_; }
+  void set_persistable(bool persistable) { persistable_ = persistable; }
   // T is the data type and R is the return type
   // For OpenCL, the return type can be cl::Buffer
   // and the data type can be float/int8_t.
@@ -177,6 +183,14 @@ class TensorLite {
  private:
   TargetType target_{TargetType::kHost};
+  // precision_ and persistable_ are only used for persistable vars.
+  // If you want the tensor to be saved and loaded correctly, you must
+  // set precision_ and persistable_ after updating it.
+  // If the tensor is just a temporary tensor, such as activations,
+  // you can ignore these two attributes.
+  PrecisionType precision_{PrecisionType::kUnk};
+  bool persistable_{false};
   DDimLite dims_;
   std::shared_ptr<Buffer> buffer_;
   LoD lod_;
......
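These two attributes are what the serialization code further down keys on: a parameter tensor has to carry both a precision and the persistable flag before SaveModel / SaveOptimizedModel is called. A minimal sketch of preparing such a tensor, modeled on the SaveParamNaive test later in this diff (the variable name, shape, and the exact Resize/DDim overload are illustrative assumptions, not part of this change):

  // Sketch only: names, shape, and the DDim constructor are illustrative.
  Scope scope;
  auto* w = scope.Var("fc0.w_0")->GetMutable<lite::Tensor>();
  w->Resize(DDim(std::vector<int64_t>({128, 64})));  // assumed DDim overload
  auto* data = w->mutable_data<float>();              // allocate FP32 storage
  for (int64_t i = 0; i < w->dims().production(); ++i) data[i] = 0.f;
  w->set_precision(PRECISION(kFloat));  // serializer will emit an FP32 data type
  w->set_persistable(true);             // mark the var as a parameter to be saved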
@@ -18,5 +18,5 @@ mobilenetv1_full_api.o: mobilenetv1_full_api.cc
 .PHONY: clean
 clean:
-	rm mobilenetv1_full_api.o
-	rm mobilenetv1_full_api
+	rm -f mobilenetv1_full_api.o
+	rm -f mobilenetv1_full_api
@@ -18,5 +18,5 @@ mobilenetv1_full_api.o: mobilenetv1_full_api.cc
 .PHONY: clean
 clean:
-	rm mobilenetv1_full_api.o
-	rm mobilenetv1_full_api
+	rm -f mobilenetv1_full_api.o
+	rm -f mobilenetv1_full_api
@@ -18,5 +18,5 @@ mobilenetv1_light_api.o: mobilenetv1_light_api.cc
 .PHONY: clean
 clean:
-	rm mobilenetv1_light_api.o
-	rm mobilenetv1_light_api
+	rm -f mobilenetv1_light_api.o
+	rm -f mobilenetv1_light_api
@@ -18,5 +18,5 @@ mobilenetv1_light_api.o: mobilenetv1_light_api.cc
 .PHONY: clean
 clean:
-	rm mobilenetv1_light_api.o
-	rm mobilenetv1_light_api
+	rm -f mobilenetv1_light_api.o
+	rm -f mobilenetv1_light_api
@@ -24,6 +24,7 @@ using namespace paddle::lite_api;  // NOLINT
 DEFINE_string(model_dir, "", "Model dir path.");
 DEFINE_string(optimized_model_dir, "", "Optimized model dir.");
+DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels");
 int64_t ShapeProduction(const shape_t& shape) {
   int64_t res = 1;
@@ -35,14 +36,27 @@ void RunModel() {
   // 1. Set CxxConfig
   CxxConfig config;
   config.set_model_dir(FLAGS_model_dir);
-  config.set_preferred_place(Place{TARGET(kARM), PRECISION(kFloat)});
-  config.set_valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
+  std::vector<Place> valid_places{Place{TARGET(kARM), PRECISION(kFloat)}};
+  if (FLAGS_prefer_int8_kernel) {
+    valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)});
+    config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)});
+  } else {
+    config.set_preferred_place(Place{TARGET(kARM), PRECISION(kFloat)});
+  }
+  config.set_valid_places(valid_places);
   // 2. Create PaddlePredictor by CxxConfig
   std::shared_ptr<PaddlePredictor> predictor =
       CreatePaddlePredictor<CxxConfig>(config);
-  // 3. Prepare input data
+  // 3. Save the optimized model
+  // WARN: The `predictor->SaveOptimizedModel` method must be executed
+  // before the `predictor->Run` method, because some kernels' `PrepareForRun`
+  // methods may change some parameters' values.
+  predictor->SaveOptimizedModel(FLAGS_optimized_model_dir,
+                                LiteModelType::kNaiveBuffer);
+  // 4. Prepare input data
   std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
   input_tensor->Resize(shape_t({1, 3, 224, 224}));
   auto* data = input_tensor->mutable_data<float>();
@@ -50,20 +64,16 @@ void RunModel() {
     data[i] = 1;
   }
-  // 4. Run predictor
+  // 5. Run predictor
   predictor->Run();
-  // 5. Get output
+  // 6. Get output
   std::unique_ptr<const Tensor> output_tensor(
       std::move(predictor->GetOutput(0)));
   printf("Output dim: %d\n", output_tensor->shape()[1]);
   for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
     printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
   }
-  // 6. Save optimition model
-  predictor->SaveOptimizedModel(FLAGS_optimized_model_dir,
-                                LiteModelType::kNaiveBuffer);
 }
 int main(int argc, char** argv) {
......
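The light-API demo consumes the naive-buffer model written by SaveOptimizedModel above. That file is not part of this diff; the following is only a rough sketch of the consumer side, assuming the MobileConfig API of the same release:

  // Hedged sketch of the light-API side; mobilenetv1_light_api.cc itself is not shown here.
  MobileConfig config;
  config.set_model_dir(FLAGS_optimized_model_dir);  // dir produced by SaveOptimizedModel
  std::shared_ptr<PaddlePredictor> predictor =
      CreatePaddlePredictor<MobileConfig>(config);
  // Resize and fill the input tensor, then call predictor->Run() as in the full-API demo.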
@@ -85,20 +85,23 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
   size_t size = tensor->dims().production() * SizeOfType(desc.data_type());
   // alllocate memory
   switch (static_cast<int>(desc.data_type())) {
-#define DO(desc, type)                    \
-  case Type::VarType_Type_##desc:         \
-    buf = tensor->mutable_data<type>();   \
-    break;
-    // DO(BOOL, bool);
-    DO(FP32, float);
-    DO(INT8, int8_t);
-    DO(INT16, int16_t);
-    DO(INT32, int32_t);
-    DO(INT64, int64_t);
-#undef DO
+#define SET_TENSOR(desc, type, precision) \
+  case Type::VarType_Type_##desc:         \
+    buf = tensor->mutable_data<type>();   \
+    tensor->set_precision(precision);     \
+    break
+    // SET_TENSOR(BOOL, bool, PRECISION(kBool));
+    SET_TENSOR(FP32, float, PRECISION(kFloat));
+    SET_TENSOR(INT8, int8_t, PRECISION(kInt8));
+    SET_TENSOR(INT16, int16_t, PRECISION(kInt16));
+    SET_TENSOR(INT32, int32_t, PRECISION(kInt32));
+    SET_TENSOR(INT64, int64_t, PRECISION(kInt64));
+#undef SET_TENSOR
     default:
       LOG(FATAL) << "unknown type " << desc.data_type();
   }
+  tensor->set_persistable(true);
   is.read(static_cast<char *>(buf), size);
 }
@@ -379,7 +382,22 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
   // void* protobuf message
   framework::proto::VarType::TensorDesc desc;
   // TODO(Superjomn) support other data types.
-  desc.set_data_type(framework::proto::VarType_Type_FP32);
+  switch (tensor.precision()) {
+#define SET_DATA_TYPE(precision, type_desc) \
+  case precision:                           \
+    desc.set_data_type(type_desc);          \
+    break
+    SET_DATA_TYPE(PRECISION(kFloat), framework::proto::VarType_Type_FP32);
+    SET_DATA_TYPE(PRECISION(kInt8), framework::proto::VarType_Type_INT8);
+    SET_DATA_TYPE(PRECISION(kInt16), framework::proto::VarType_Type_INT16);
+    SET_DATA_TYPE(PRECISION(kInt32), framework::proto::VarType_Type_INT32);
+    SET_DATA_TYPE(PRECISION(kInt64), framework::proto::VarType_Type_INT64);
+#undef SET_DATA_TYPE
+    default:
+      LOG(FATAL) << "unknown precision type: "
+                 << PrecisionToStr(tensor.precision());
+  }
   auto dims = tensor.dims();
   auto *pb_dims = desc.mutable_dims();
   pb_dims->Resize(static_cast<int>(dims.size()), 0);
@@ -444,7 +462,22 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
   desc.SetLoD(tensor.lod());
   // TODO(sangoly): support other data types.
-  desc.SetDataType(VarDescAPI::VarDataType::FP32);
+  switch (tensor.precision()) {
+#define SET_DATA_TYPE(precision, type_desc) \
+  case precision:                           \
+    desc.SetDataType(type_desc);            \
+    break
+    SET_DATA_TYPE(PRECISION(kFloat), VarDescAPI::VarDataType::FP32);
+    SET_DATA_TYPE(PRECISION(kInt8), VarDescAPI::VarDataType::INT8);
+    SET_DATA_TYPE(PRECISION(kInt16), VarDescAPI::VarDataType::INT16);
+    SET_DATA_TYPE(PRECISION(kInt32), VarDescAPI::VarDataType::INT32);
+    SET_DATA_TYPE(PRECISION(kInt64), VarDescAPI::VarDataType::INT64);
+#undef SET_DATA_TYPE
+    default:
+      LOG(FATAL) << "unknown precision type: "
+                 << PrecisionToStr(tensor.precision());
+  }
   desc.SetDim(tensor.dims().Vectorize());
   uint64_t size = tensor.memory_size();
   CHECK_LT(size, std::numeric_limits<std::streamsize>::max())
@@ -452,16 +485,44 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
 #ifdef LITE_WITH_CUDA
   if (tensor.target() == TARGET(kCUDA)) {
-    std::unique_ptr<float> tmp_buffer(new float[tensor.data_size()]);
-    TargetWrapperCuda::MemcpySync(tmp_buffer.get(),
-                                  tensor.data<float>(),
-                                  tensor.data_size(),
-                                  IoDirection::DtoH);
-    desc.SetData<float>(tmp_buffer.get(), tensor.data_size());
+    switch (tensor.precision()) {
+#define DO(precision, type)                                          \
+  case precision:                                                    \
+    std::unique_ptr<type> tmp_buffer(new type[tensor.data_size()]);  \
+    TargetWrapperCuda::MemcpySync(tmp_buffer.get(),                  \
+                                  tensor.data<type>(),               \
+                                  tensor.data_size(),                \
+                                  IoDirection::DtoH);                \
+    desc.SetData<type>(tmp_buffer.get(), tensor.data_size());        \
+    break
+      DO(PRECISION(kFloat), float);
+      DO(PRECISION(kInt8), int8_t);
+      DO(PRECISION(kInt16), int16_t);
+      DO(PRECISION(kInt32), int32_t);
+      DO(PRECISION(kInt64), int64_t);
+#undef DO
+      default:
+        LOG(FATAL) << "unknown precision type: "
+                   << PrecisionToStr(tensor.precision());
+    }
   } else  // NOLINT
 #endif    // LITE_WITH_CUDA
   {
-    desc.SetData<float>(tensor.data<float>(), tensor.data_size());
+    switch (tensor.precision()) {
+#define DO(precision, type)                                       \
+  case precision:                                                 \
+    desc.SetData<type>(tensor.data<type>(), tensor.data_size());  \
+    break
+      DO(PRECISION(kFloat), float);
+      DO(PRECISION(kInt8), int8_t);
+      DO(PRECISION(kInt16), int16_t);
+      DO(PRECISION(kInt32), int32_t);
+      DO(PRECISION(kInt64), int64_t);
+#undef DO
+      default:
+        LOG(FATAL) << "unknown precision type: "
+                   << PrecisionToStr(tensor.precision());
+    }
   }
 }
@@ -566,22 +627,24 @@ void GetParamInfoNaive(const naive_buffer::ParamDesc &desc,
   // Load data
   switch (desc.GetDataType()) {
-#define DO(data_type__, T)                                               \
-  case VarDescAPI::VarDataType::data_type__:                             \
-    SetTensorDataNaive<T>(                                               \
-        tensor->mutable_data<T>(), tensor->data_size(), desc.Data<T>()); \
-    break
-    // DO(BOOL, bool);
-    DO(FP32, float);
-    DO(INT8, int8_t);
-    DO(INT16, int16_t);
-    DO(INT32, int32_t);
-    DO(INT64, int64_t);
-#undef DO
+#define SET_TENSOR(data_type__, T, precision)                            \
+  case VarDescAPI::VarDataType::data_type__:                             \
+    SetTensorDataNaive<T>(                                               \
+        tensor->mutable_data<T>(), tensor->data_size(), desc.Data<T>()); \
+    tensor->set_precision(precision);                                    \
+    break
+    // SET_TENSOR(BOOL, bool, PRECISION(kBool));
+    SET_TENSOR(FP32, float, PRECISION(kFloat));
+    SET_TENSOR(INT8, int8_t, PRECISION(kInt8));
+    SET_TENSOR(INT16, int16_t, PRECISION(kInt16));
+    SET_TENSOR(INT32, int32_t, PRECISION(kInt32));
+    SET_TENSOR(INT64, int64_t, PRECISION(kInt64));
+#undef SET_TENSOR
     default:
       LOG(FATAL) << "unknown type";
   }
+  tensor->set_persistable(true);
 }
 void LoadParamNaive(const std::string &path,
......
@@ -75,6 +75,8 @@ TEST(ModelParser, LoadModelCombinedPb) {
 TEST(ModelParser, SaveParamNaive) {
   Scope scope;
   auto* tensor = scope.Var("xxx")->GetMutable<lite::Tensor>();
+  tensor->set_precision(PRECISION(kFloat));
+  tensor->set_persistable(true);
   auto& lod = *tensor->mutable_lod();
   lod.resize(2);
   lod[0] = {1, 2, 3};
......
@@ -169,7 +169,7 @@ GET_DATA_IMPL(double, FP64);
 // NOTE: Must set data type first
 #define SET_DATA_COMMON_IMPL(T, type__, size__, data_ptr__)        \
   CHECK(GetDataType() == VarDescAPI::VarDataType::type__)          \
-      << "Data Type mismatch, Call SetDataType first";             \
+      << "Data Type mismatch, call SetDataType first.";            \
   auto* data_builder =                                             \
       desc_->GetMutableField<ListBuilder<CharBuilder>>("data");    \
   CHECK(data_builder);                                             \
......