Commit 29f9aade authored by S Superjomn

add light api for mobile

Parent e7f32773
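
The new light API reloads a model directory that the CXX predictor has already optimized and persisted with SaveModel(), so the mobile binary needs neither the optimizer nor the full framework. A rough usage sketch of the light predictor added in light_api.h below (the directory path is a placeholder, not the one hard-coded in the test):

    #include "paddle/fluid/lite/api/light_api.h"

    void RunOptimizedModel() {
      // The directory must contain __model__ plus one file per persistable
      // variable, as written by CxxPredictor::SaveModel() via RuntimeProgram::PersistModel().
      paddle::lite::CxxPredictor predictor;
      predictor.Build("./optimized_model");  // placeholder path

      // Feed slot 0, run, and read fetch slot 0.
      auto* input = predictor.GetInput(0);
      input->Resize({100, 100});
      float* data = input->mutable_data<float>();
      for (int i = 0; i < 100 * 100; i++) data[i] = 1.f;

      predictor.Run();
      const auto* output = predictor.GetOutput(0);
      (void)output;
    }
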
+set(cxx_api_lite_deps scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host kernels_cuda optimizer_lite model_parser_lite)
if(LITE_WITH_CUDA)
-    cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda optimizer_lite)
+    cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} target_wrapper_cuda)
-    nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
+    nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
-else()
-    cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
-    cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
endif()
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps})

+set(light_api_deps
+    scope_lite host_kernels ops_lite target_wrapper_host model_parser_lite)
+if(LITE_WITH_CUDA)
+    set(light_api_deps ${light_api_deps} target_wrapper_cuda)
+endif()
+cc_library(light_api_lite SRCS light_api.cc DEPS ${light_api_deps})

+cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
+cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite)
@@ -18,7 +18,7 @@
namespace paddle {
namespace lite {

-void Predictor::SaveModel(const std::string &dir) {
+void CxxPredictor::SaveModel(const std::string &dir) {
  MkDirRecursively(dir.c_str());
  program_->PersistModel(dir, program_desc_);
}
...
@@ -25,9 +25,9 @@ namespace lite {
struct Config {};

-class Predictor {
+class CxxPredictor {
 public:
-  Predictor() { scope_ = std::make_shared<Scope>(); }
+  CxxPredictor() { scope_ = std::make_shared<Scope>(); }

  void Build(const std::string& model_path, const Place& prefer_place,
             const std::vector<Place>& valid_places) {
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace lite {

TEST(CXXApi, test) {
-  lite::Predictor predictor;
+  lite::CxxPredictor predictor;
#ifndef LITE_WITH_CUDA
  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
@@ -60,7 +60,7 @@ TEST(CXXApi, test) {
}

TEST(CXXApi, save_model) {
-  lite::Predictor predictor;
+  lite::CxxPredictor predictor;
  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
  predictor.Build("/home/chunwei/project/models/model2",
                  Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/light_api.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements a light-weight API which can run on mobile. We limit the
* dependencies and the runtime computation complexity.
*/
#pragma once
#include <algorithm>  // std::find_if, used in BuildRuntimeProgram below
#include <memory>     // std::shared_ptr / std::unique_ptr
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace paddle {
namespace lite {
class CxxPredictor {
public:
CxxPredictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_dir) {
framework::proto::ProgramDesc desc;
LoadModel(model_dir, scope_.get(), &desc);
BuildRuntimeProgram(desc);
}
void Run() { program_->Run(); }
// Get offset-th col of feed.
Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
if (offset >= feed_list->size()) {
feed_list->resize(offset + 1);
}
return &feed_list->at(offset);
}
const Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fetch variable in exec_scope";
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}
private:
void BuildRuntimeProgram(const framework::proto::ProgramDesc& prog) {
std::vector<Instruct> insts;
// 1. Create op first
Program program(prog, scope_, {});
// 2. Create Instructs
// Create the kernels of the target places, and filter out the specific
// kernel with the target alias.
for (auto& op : program.ops) {
lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
auto kernels = op->CreateKernels({place});
// filter out a kernel
auto it = std::find_if(kernels.begin(), kernels.end(),
[&](std::unique_ptr<KernelBase>& it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope);
program_->set_exec_scope(program.exec_scope);
}
private:
std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_;
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/light_api.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
const std::string model_dir =
"/home/chunwei/project/Paddle/cmake-build-relwithdebinfo/paddle/fluid/lite/"
"api/optimized_model";
TEST(LightAPI, load) {
CxxPredictor predictor;
predictor.Build(model_dir);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor.Run();
}
} // namespace lite
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(fc);
USE_LITE_OP(scale);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
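
The USE_LITE_OP / USE_LITE_KERNEL lines above reference the registered ops and kernels so they end up linked into the test binary; the light predictor can only instantiate kernels whose registration is present, since BuildRuntimeProgram looks them up by the op type, alias and place stored in the __@kernel_type_attr@__ attribute. A model that relied on an additional kernel would presumably need a matching pair of declarations, for example (both the softmax op and the kARM target here are purely illustrative and not part of this commit):

    USE_LITE_OP(softmax);                                // hypothetical extra op
    USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);  // hypothetical ARM kernel
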
@@ -21,7 +21,7 @@ cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)

cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
-cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
+cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
...
@@ -96,14 +96,40 @@ class KernelBase {
  // Generate the key of the parameter type.
  std::string GenParamTypeKey() const;

-  std::string SerializeKernelType() const {
+  std::string SerializedKernelType() const {
+    return SerializeKernelType(op_type(), alias(), place());
+  }
+
+  static std::string SerializeKernelType(const std::string& op_type,
+                                         const std::string& alias,
+                                         const Place& place) {
    std::stringstream ss;
-    ss << op_type() << "/";
-    ss << alias_ << "/";
-    ss << place();
+    ss << op_type << "/";
+    ss << alias << "/";
+    // We serialize the place value not the string representation here for
+    // easier deserialization.
+    ss << static_cast<int>(place.target) << "/";
+    ss << static_cast<int>(place.precision) << "/";
+    ss << static_cast<int>(place.layout);
    return ss.str();
  }

+  static void ParseKernelType(const std::string& kernel_type,
+                              std::string* op_type, std::string* alias,
+                              Place* place) {
+    std::stringstream ss(kernel_type);
+    std::getline(ss, *op_type, '/');
+    std::getline(ss, *alias, '/');
+    std::string target, precision, layout;
+    std::getline(ss, target, '/');
+    std::getline(ss, precision, '/');
+    std::getline(ss, layout, '/');
+    place->target = static_cast<TargetType>(std::stoi(target));
+    place->precision = static_cast<PrecisionType>(std::stoi(precision));
+    place->layout = static_cast<DataLayoutType>(std::stoi(layout));
+  }
+
  virtual ~KernelBase() = default;
  void Torch() {}
...
@@ -42,6 +42,22 @@ TEST(Kernel, test) {
  ASSERT_EQ(test_code, 100);
}

+TEST(Kernel, kernel_type) {
+  const std::string op_type = "fc";
+  const std::string alias = "def";
+  Place place(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  auto kernel_type = KernelBase::SerializeKernelType(op_type, alias, place);
+  LOG(INFO) << "kernel_type: " << kernel_type;
+  ASSERT_EQ(kernel_type, "fc/def/1/1/1");
+
+  std::string op_type1, alias1;
+  Place place1;
+  KernelBase::ParseKernelType(kernel_type, &op_type1, &alias1, &place1);
+  ASSERT_EQ(op_type, op_type1);
+  ASSERT_EQ(alias, alias1);
+  ASSERT_EQ(place, place1);
+}
+
}  // namespace core
}  // namespace lite
}  // namespace paddle
@@ -38,7 +38,7 @@ class GenerateProgramPass : public ProgramPass {
  }

 private:
-  std::vector<Instruction> insts_;
+  std::vector<Instruct> insts_;
};

}  // namespace mir
...
@@ -18,53 +18,48 @@
namespace paddle {
namespace lite {

-void RuntimeProgram::PersistModel(const std::string &path,
+void RuntimeProgram::PersistModel(const std::string &dir,
                                  const framework::proto::ProgramDesc &desc) {
  // Persist model.
-  const std::string model_path = path + "/__model__";
+  const std::string model_path = dir + "/__model__";
  std::ofstream model_ostream(model_path, std::ios_base::binary);
  CHECK(model_ostream.is_open());
-  const std::string pb_str = SerializeModelTopology(desc);
+  const std::string pb_str = SerializeProgram(desc);
  model_ostream.write(pb_str.c_str(), pb_str.size());
+  model_ostream.close();

  // Persist params.
-  const std::string params_path = path + "/params";
-  CHECK(!IsFileExists(params_path)) << "file " << params_path
-                                    << " exists, can't overwrite";
-  std::ofstream params_ostream(params_path, std::ios_base::binary);
-  CHECK(params_ostream.is_open());
  framework::proto::ProgramDesc latest_program;
  latest_program.ParseFromString(pb_str);
-  SerializeParams(params_ostream, latest_program);
+  SaveParams(dir, latest_program);
}

-std::string RuntimeProgram::SerializeModelTopology(
+std::string RuntimeProgram::SerializeProgram(
    const framework::proto::ProgramDesc &desc) {
-  const std::string kKernelTypeAttr = "__@kernel_type_attr@__";
  auto program_dummy = desc;
  program_dummy.mutable_blocks(0)->clear_ops();
  for (auto &node : instructions_) {
    auto desc_dummy = node.op()->op_info()->desc();
    OpDesc desc(desc_dummy);
-    desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializeKernelType());
+    desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
    // append new opdesc
    *program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto();
  }
  return program_dummy.SerializeAsString();
}

-void RuntimeProgram::SerializeParams(
-    std::ostream &os, const framework::proto::ProgramDesc &desc) {
-  std::vector<std::string> ws;
+void RuntimeProgram::SaveParams(const std::string &dir,
+                                const framework::proto::ProgramDesc &desc) {
+  CHECK(exec_scope_);
  for (auto &item : desc.blocks(0).vars()) {
+    const std::string path = dir + "/" + item.name();
    if (item.name() == "feed" || item.name() == "fetch") continue;
    if (item.persistable()) {
-      ws.push_back(item.name());
+      std::ofstream file(path, std::ios::binary);
+      SerializeTensor(file, *exec_scope_, item.name());
+      file.close();
    }
  }
-  CHECK(exec_scope_);
-  SerializeTensors(os, *exec_scope_, ws);
}

}  // namespace lite
...
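
With the PersistModel/SaveParams change above, a persisted model directory now holds the serialized program plus one raw tensor file per persistable variable (feed and fetch are skipped), instead of a single packed params file; this is the layout the light API's Build() loads back through LoadModel(). A sketch of the on-disk layout, with illustrative weight names that are not taken from this diff:

    optimized_model/
        __model__     # ProgramDesc; each op carries the __@kernel_type_attr@__ string
        fc_0.w_0      # one tensor stream per persistable variable (names illustrative)
        fc_0.b_0
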
@@ -26,6 +26,8 @@
namespace paddle {
namespace lite {

+static const std::string kKernelTypeAttr = "__@kernel_type_attr@__";
+
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
@@ -46,8 +48,9 @@ struct Program {
          const std::shared_ptr<Scope>& root,
          const std::vector<Place>& valid_places)
      : scope(root), valid_places(valid_places), desc(desc) {
+    CHECK(scope) << "scope should be init first";
    PrepareWorkspace(desc);
-    Build(desc, valid_places);
+    Build(desc);
  }

  std::unique_ptr<Program> Clone() const {
@@ -57,8 +60,7 @@ struct Program {
 private:
  // Build from a program and scope.
-  void Build(const framework::proto::ProgramDesc& program,
-             const std::vector<Place>& valid_places) {
+  void Build(const framework::proto::ProgramDesc& program) {
    CHECK(ops.empty()) << "Executor duplicate Build found";

    // Create operators.
@@ -67,7 +69,9 @@ struct Program {
      auto op_type = op_desc.Type();
      // if (op_type == "feed" || op_type == "fetch") continue;
      VLOG(4) << "create Op [" << op_type << "]";
-      ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
+      auto op = LiteOpRegistry::Global().Create(op_type);
+      CHECK(op) << "no Op found for " << op_type;
+      ops.emplace_back(op);
      ops.back()->Attach(op_desc, exec_scope);
    }
  }
@@ -95,9 +99,9 @@ struct Program {
  }
};

-struct Instruction {
-  Instruction(const std::shared_ptr<OpLite>& op,
+struct Instruct {
+  Instruct(const std::shared_ptr<OpLite>& op,
            std::unique_ptr<KernelBase>&& kernel)
      : op_(op), kernel_(std::move(kernel)) {}

  void Run() {
@@ -111,7 +115,7 @@ struct Instruction {
    kernel_->Run();
  }

-  friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
+  friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
    os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
    return os;
  }
@@ -130,7 +134,7 @@ struct Instruction {
 */
class RuntimeProgram {
 public:
-  explicit RuntimeProgram(std::vector<Instruction>&& insts)
+  explicit RuntimeProgram(std::vector<Instruct>&& insts)
      : instructions_(std::move(insts)) {
    if (instructions_.empty()) {
      LOG(FATAL) << "no instructions";
@@ -145,7 +149,7 @@ class RuntimeProgram {
  }

  // Serialize the graph and save to the disk.
-  void PersistModel(const std::string& path,
+  void PersistModel(const std::string& dir,
                    const framework::proto::ProgramDesc& desc);

  void set_exec_scope(lite::Scope* x) { exec_scope_ = x; }
@@ -154,13 +158,13 @@ class RuntimeProgram {
  size_t num_instructions() const { return instructions_.size(); }

 protected:
-  std::string SerializeModelTopology(const framework::proto::ProgramDesc& desc);
+  std::string SerializeProgram(const framework::proto::ProgramDesc& desc);
-  void SerializeParams(std::ostream& os,
+  void SaveParams(const std::string& dir,
                  const framework::proto::ProgramDesc& desc);

 private:
  RuntimeProgram(const RuntimeProgram&) = delete;
-  std::vector<Instruction> instructions_;
+  std::vector<Instruct> instructions_;
  lite::Scope* exec_scope_{};
};
...
-cc_library(target_wrapper_host SRCS target_wrapper.cc)
+cc_library(target_wrapper_host SRCS target_wrapper.cc DEPS target_wrapper_lite)
@@ -6,8 +6,13 @@ else()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)

-cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
-        compatible_pb_lite)
+set(model_parser_deps variable_lite scope_lite tensor_lite scope_lite
+    target_wrapper_host
+    compatible_pb_lite
+    )
+if (LITE_WITH_CUDA)
+    set(model_parser_deps ${model_parser_deps} target_wrapper_cuda)
+endif()
+cc_library(model_parser_lite SRCS model_parser.cc DEPS ${model_parser_deps})

add_subdirectory(pb)
@@ -35,7 +35,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
    DO(INT64, int64_t);
#undef DO
    default:
-      LOG(FATAL) << "unknown data type";
+      LOG(FATAL) << "unknown data type " << type;
  }
  return -1;
}
@@ -86,12 +86,12 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
void LoadLoDTensor(std::istream &is, Variable *var) {
  auto *tensor = var->GetMutable<lite::Tensor>();
-  uint32_t version;
+  uint32_t version{};
  is.read(reinterpret_cast<char *>(&version), sizeof(version));
  LOG(INFO) << "model version " << version;

  // Load LoD information
-  uint64_t lod_level;
+  uint64_t lod_level{};
  is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
  auto &lod = *tensor->mutable_lod();
  lod.resize(lod_level);
@@ -136,6 +136,7 @@ void LoadParams(const std::string &path) {}
// Load directly to CPU, and latter transfer to other devices.
void LoadParam(const std::string &path, Variable *out) {
  std::ifstream fin(path, std::ios::binary);
+  CHECK(fin.is_open()) << "failed to open file " << path;
  LoadLoDTensor(fin, out);
}
@@ -164,13 +165,12 @@ void LoadModel(const std::string &model_dir, Scope *scope,
}

void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
-  {  // the 1st field, uint32_t version
+  // the 1st field, uint32_t version
  constexpr uint32_t version = 0;
  os.write(reinterpret_cast<const char *>(&version), sizeof(version));
-  }

  {
-    int size = tensor.lod().size();
+    uint64_t size = tensor.lod().size();
    // the 2st field, LoD information
    // uint64_t lod_level
    // uint64_t lod_level_1 size in byte.
@@ -186,11 +186,15 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
    }
  }

+  // There are two version fields in a LoDTensor.
+  os.write(reinterpret_cast<const char *>(&version), sizeof(version));
+
  {  // the 2nd field, tensor description
    // int32_t size
    // void* protobuf message
    framework::proto::VarType::TensorDesc desc;
-    desc.set_data_type(framework::proto::VarType_Type_LOD_TENSOR);
+    // TODO(Superjomn) support other data types.
+    desc.set_data_type(framework::proto::VarType_Type_FP32);
    auto dims = tensor.dims();
    auto *pb_dims = desc.mutable_dims();
    pb_dims->Resize(static_cast<int>(dims.size()), 0);
@@ -221,14 +225,12 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
  }
}

-void SerializeTensors(std::ostream &os, const lite::Scope &scope,
-                      const std::vector<std::string> &vars) {
+void SerializeTensor(std::ostream &os, const lite::Scope &scope,
+                     const std::string &var_name) {
  // Store all the persistable vars.
-  for (const auto &_var : vars) {
-    auto *var = scope.FindVar(_var);
-    const auto &tensor = var->Get<lite::Tensor>();
-    TensorToStream(os, tensor);
-  }
+  auto *var = scope.FindVar(var_name);
+  const auto &tensor = var->Get<lite::Tensor>();
+  TensorToStream(os, tensor);
}

}  // namespace lite
...
@@ -41,8 +41,8 @@ void LoadModel(const std::string& model_dir, Scope* scope,
               framework::proto::ProgramDesc* prog);

// Serialize tensors to ostream.
-void SerializeTensors(std::ostream& os, const lite::Scope& scope,
-                      const std::vector<std::string>& vars);
+void SerializeTensor(std::ostream& os, const lite::Scope& scope,
+                     const std::string& var);

// LoDTensor to ostream
void TensorToStream(std::ostream& os, const lite::Tensor& tensor);
...
@@ -22,7 +22,7 @@ namespace lite {
template <>
void TargetWrapper<TARGET(kX86)>::MemcpySync(void *dst, const void *src,
                                             size_t size, IoDirection dir) {
-  std::copy_n(reinterpret_cast<uint8_t *>(src), size,
+  std::copy_n(reinterpret_cast<const uint8_t *>(src), size,
              reinterpret_cast<uint8_t *>(dst));
}
...