// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/lite/model_parser/model_parser.h" #include #include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/core/variable.h" namespace paddle { namespace lite { int SizeOfType(framework::proto::VarType::Type type) { using Type = framework::proto::VarType::Type; switch (static_cast(type)) { #define DO(desc, type) \ case Type::VarType_Type_##desc: \ return sizeof(type); DO(BOOL, bool); DO(FP16, float); DO(FP32, float); DO(INT8, int8_t); DO(INT32, int); DO(INT64, int64_t); #undef DO default: LOG(FATAL) << "unknown data type"; } } void TensorFromStream(std::istream &is, lite::Tensor *tensor) { using Type = framework::proto::VarType::Type; uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); CHECK_EQ(version, 0U) << "Only version 0 is supported"; // read tensor desc framework::proto::VarType::TensorDesc desc; { // int32_t size // proto buffer int32_t size; is.read(reinterpret_cast(&size), sizeof(size)); std::unique_ptr buf(new char[size]); is.read(reinterpret_cast(buf.get()), size); CHECK(desc.ParseFromArray(buf.get(), size)) << "Cannot parse tensor desc"; } // read tensor std::vector dims; dims.reserve(static_cast(desc.dims().size())); std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); tensor->Resize(dims); void *buf; size_t size = product(tensor->dims()) * SizeOfType(desc.data_type()); // alllocate memory switch (static_cast(desc.data_type())) { #define DO(desc, type) \ case Type::VarType_Type_##desc: \ buf = tensor->mutable_data(); \ break; DO(BOOL, bool); DO(FP32, float); DO(INT8, int8_t); DO(INT16, int16_t); DO(INT32, int32_t); DO(INT64, int64_t); #undef DO default: LOG(FATAL) << "unknown type"; } is.read(static_cast(buf), size); } void LoadLoDTensor(std::istream &is, Variable *var) { auto *tensor = var->GetMutable(); uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); LOG(INFO) << "model version " << version; // Load LoD information uint64_t lod_level; is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); auto &lod = *tensor->mutable_lod(); lod.resize(lod_level); for (uint64_t i = 0; i < lod_level; ++i) { uint64_t size; is.read(reinterpret_cast(&size), sizeof(size)); std::vector tmp(size / sizeof(size_t)); is.read(reinterpret_cast(tmp.data()), static_cast(size)); lod[i] = tmp; } TensorFromStream(is, tensor); } // TODO(Superjomn) support SelectedRows. void ReadBinaryFile(const std::string &filename, std::string *contents) { std::ifstream fin(filename, std::ios::in | std::ios::binary); CHECK(fin.is_open()) << "Cannot open file " << filename; fin.seekg(0, std::ios::end); auto size = fin.tellg(); contents->clear(); contents->resize(size); fin.seekg(0, std::ios::beg); fin.read(&(contents->at(0)), contents->size()); fin.close(); } std::unique_ptr LoadProgram( const std::string &path) { std::string desc_str; ReadBinaryFile(path, &desc_str); std::unique_ptr main_program( new framework::proto::ProgramDesc); main_program->ParseFromString(desc_str); return main_program; } void LoadParams(const std::string &path) {} void LoadModel(const std::string &model_dir, Scope *scope) { const std::string prog_path = model_dir + "/__model__"; auto prog = LoadProgram(prog_path); auto main_block = prog->blocks(0); for (auto &var : main_block.vars()) { std::string file_path = model_dir + "/" + var.name(); std::ifstream file(file_path); LoadLoDTensor(file, scope->Var(var.name())); } } } // namespace lite } // namespace paddle