“8f156b5e2e222167172f1f2dae32ead2adcf9001”上不存在“develop/api_doc/fluid/executor.html”
提交 e7f32773 编写于 作者: S Superjomn

enable optimized model persist

上级 f1ca00a4
if(LITE_WITH_CUDA) if(LITE_WITH_CUDA)
cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda) cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda optimizer_lite)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite) nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
else() else()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host ) cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels) cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
endif() endif()
...@@ -13,7 +13,15 @@ ...@@ -13,7 +13,15 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace lite {} // namespace lite namespace lite {
void Predictor::SaveModel(const std::string &dir) {
MkDirRecursively(dir.c_str());
program_->PersistModel(dir, program_desc_);
}
} // namespace lite
} // namespace paddle } // namespace paddle
...@@ -31,19 +31,19 @@ class Predictor { ...@@ -31,19 +31,19 @@ class Predictor {
void Build(const std::string& model_path, const Place& prefer_place, void Build(const std::string& model_path, const Place& prefer_place,
const std::vector<Place>& valid_places) { const std::vector<Place>& valid_places) {
framework::proto::ProgramDesc prog; LoadModel(model_path, scope_.get(), &program_desc_);
LoadModel(model_path, scope_.get(), &prog);
Program program(prog, scope_, valid_places); Program program(program_desc_, scope_, valid_places);
Optimizer optimizer; optimizer_.KernelPickPreferPlace(prefer_place);
optimizer.KernelPickPreferPlace(prefer_place);
core::KernelPickFactor factor; core::KernelPickFactor factor;
factor.ConsiderTarget(); factor.ConsiderTarget();
optimizer.Run(std::move(program), valid_places, factor); optimizer_.Run(std::move(program), valid_places, factor);
program_ = optimizer.GenRuntimeProgram(); program_ = optimizer_.GenRuntimeProgram();
} }
void SaveModel(const std::string& dir);
// Get offset-th col of feed. // Get offset-th col of feed.
Tensor* GetInput(size_t offset) { Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed"); auto* _feed_list = program_->exec_scope()->FindVar("feed");
...@@ -65,7 +65,13 @@ class Predictor { ...@@ -65,7 +65,13 @@ class Predictor {
void Run() { program_->Run(); } void Run() { program_->Run(); }
const framework::proto::ProgramDesc& program_desc() const {
return program_desc_;
}
private: private:
Optimizer optimizer_;
framework::proto::ProgramDesc program_desc_;
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
}; };
......
...@@ -36,7 +36,7 @@ TEST(CXXApi, test) { ...@@ -36,7 +36,7 @@ TEST(CXXApi, test) {
}); });
#endif #endif
predictor.Build("/home/chunwei/project2/models/model2", predictor.Build("/home/chunwei/project/models/model2",
Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places); Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
...@@ -59,6 +59,15 @@ TEST(CXXApi, test) { ...@@ -59,6 +59,15 @@ TEST(CXXApi, test) {
LOG(INFO) << "out " << *out; LOG(INFO) << "out " << *out;
} }
TEST(CXXApi, save_model) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
predictor.Build("/home/chunwei/project/models/model2",
Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
predictor.SaveModel("./optimized_model");
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
......
...@@ -12,13 +12,13 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l ...@@ -12,13 +12,13 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite) cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS tensor_lite) cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager)
cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite scope_lite op_registry_lite proto_desc op_lite
ops_lite ops_lite
host_kernels host_kernels
) )
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite) cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86) cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
......
...@@ -96,6 +96,14 @@ class KernelBase { ...@@ -96,6 +96,14 @@ class KernelBase {
// Generate the key of the parameter type. // Generate the key of the parameter type.
std::string GenParamTypeKey() const; std::string GenParamTypeKey() const;
std::string SerializeKernelType() const {
std::stringstream ss;
ss << op_type() << "/";
ss << alias_ << "/";
ss << place();
return ss.str();
}
virtual ~KernelBase() = default; virtual ~KernelBase() = default;
void Torch() {} void Torch() {}
......
...@@ -22,7 +22,7 @@ namespace mir { ...@@ -22,7 +22,7 @@ namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
LOG(INFO) << "final program \n" << Visualize(graph.get()); LOG(INFO) << "final program \n" << Visualize(graph.get());
for (auto& item : graph->InstructTopologicalOrder()) { for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) { if (item->IsStmt()) {
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
LOG(INFO) << stmt; LOG(INFO) << stmt;
......
...@@ -31,6 +31,7 @@ class GenerateProgramPass : public ProgramPass { ...@@ -31,6 +31,7 @@ class GenerateProgramPass : public ProgramPass {
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram() { std::unique_ptr<RuntimeProgram> GenProgram() {
LOG(INFO) << "insts.size " << insts_.size();
std::unique_ptr<RuntimeProgram> program( std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_))); new RuntimeProgram(std::move(insts_)));
return program; return program;
......
...@@ -71,7 +71,7 @@ void SSAGraph::SortHelper( ...@@ -71,7 +71,7 @@ void SSAGraph::SortHelper(
ret->push_back(node); ret->push_back(node);
} }
std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() { std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
CheckBidirectionalConnection(); CheckBidirectionalConnection();
std::stack<mir::Node *> stack; std::stack<mir::Node *> stack;
......
...@@ -39,7 +39,7 @@ class SSAGraph : GraphBase { ...@@ -39,7 +39,7 @@ class SSAGraph : GraphBase {
mir::Node *Argument(const std::string &name); mir::Node *Argument(const std::string &name);
std::vector<mir::Node *> InstructTopologicalOrder(); std::vector<mir::Node *> StmtTopologicalOrder();
// The inputs of the graph. // The inputs of the graph.
std::vector<mir::Node *> inputs(); std::vector<mir::Node *> inputs();
......
...@@ -58,7 +58,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -58,7 +58,7 @@ class VariablePlaceInferencePass : public DebugPass {
void InferenceArgumentPlace(SSAGraph* graph) { void InferenceArgumentPlace(SSAGraph* graph) {
VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global(); VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) { for (auto& x : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt(); auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference. // The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue; if (inst.op_type == "io_copy") continue;
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h" #include "paddle/fluid/lite/core/optimizer.h"
#include <fstream>
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h" #include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h" #include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -35,6 +36,7 @@ class Optimizer { ...@@ -35,6 +36,7 @@ class Optimizer {
void Run(Program&& program, const std::vector<Place>& valid_places, void Run(Program&& program, const std::vector<Place>& valid_places,
core::KernelPickFactor kernel_pick_factor, core::KernelPickFactor kernel_pick_factor,
const std::vector<std::string>& passes = {}) { const std::vector<std::string>& passes = {}) {
program_ = &program;
valid_places_ = valid_places; valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set"; CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found"; CHECK(!graph_) << "duplicate optimize found";
...@@ -100,6 +102,11 @@ class Optimizer { ...@@ -100,6 +102,11 @@ class Optimizer {
return *graph_; return *graph_;
} }
mir::SSAGraph* mutable_ssa_graph() {
CHECK(graph_);
return graph_.get();
}
protected: protected:
void SpecifyKernelPickTactic(core::KernelPickFactor factor); void SpecifyKernelPickTactic(core::KernelPickFactor factor);
...@@ -117,6 +124,7 @@ class Optimizer { ...@@ -117,6 +124,7 @@ class Optimizer {
std::unique_ptr<mir::SSAGraph> graph_; std::unique_ptr<mir::SSAGraph> graph_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
lite::Scope* exec_scope_{}; lite::Scope* exec_scope_{};
Program* program_{};
}; };
} // namespace lite } // namespace lite
......
...@@ -28,15 +28,10 @@ TEST(Optimizer, test) { ...@@ -28,15 +28,10 @@ TEST(Optimizer, test) {
auto program = ProgramFaker(); auto program = ProgramFaker();
std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}}); std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}});
auto* pick_pass = core::KernelPickFactor factor;
mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>( factor.ConsiderTarget();
"static_kernel_pick_pass");
ASSERT_TRUE(pick_pass != nullptr);
pick_pass->mutable_kernel_pick_factors()
->ConsiderTarget()
.ConsiderPrecision();
optimizer.Run(std::move(program), places); optimizer.Run(std::move(program), places, factor);
auto runtime_program = optimizer.GenRuntimeProgram(); auto runtime_program = optimizer.GenRuntimeProgram();
LOG(INFO) << "num statements " << runtime_program->num_instructions(); LOG(INFO) << "num statements " << runtime_program->num_instructions();
} }
...@@ -45,4 +40,4 @@ TEST(Optimizer, test) { ...@@ -45,4 +40,4 @@ TEST(Optimizer, test) {
} // namespace paddle } // namespace paddle
USE_LITE_OP(fc); USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat, def); USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
...@@ -13,3 +13,59 @@ ...@@ -13,3 +13,59 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/optimizer.h"
namespace paddle {
namespace lite {
void RuntimeProgram::PersistModel(const std::string &path,
const framework::proto::ProgramDesc &desc) {
// Persist model.
const std::string model_path = path + "/__model__";
std::ofstream model_ostream(model_path, std::ios_base::binary);
CHECK(model_ostream.is_open());
const std::string pb_str = SerializeModelTopology(desc);
model_ostream.write(pb_str.c_str(), pb_str.size());
// Persist params.
const std::string params_path = path + "/params";
CHECK(!IsFileExists(params_path)) << "file " << params_path
<< " exists, can't overwrite";
std::ofstream params_ostream(params_path, std::ios_base::binary);
CHECK(params_ostream.is_open());
framework::proto::ProgramDesc latest_program;
latest_program.ParseFromString(pb_str);
SerializeParams(params_ostream, latest_program);
}
std::string RuntimeProgram::SerializeModelTopology(
const framework::proto::ProgramDesc &desc) {
const std::string kKernelTypeAttr = "__@kernel_type_attr@__";
auto program_dummy = desc;
program_dummy.mutable_blocks(0)->clear_ops();
for (auto &node : instructions_) {
auto desc_dummy = node.op()->op_info()->desc();
OpDesc desc(desc_dummy);
desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializeKernelType());
// append new opdesc
*program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto();
}
return program_dummy.SerializeAsString();
}
void RuntimeProgram::SerializeParams(
std::ostream &os, const framework::proto::ProgramDesc &desc) {
std::vector<std::string> ws;
for (auto &item : desc.blocks(0).vars()) {
if (item.name() == "feed" || item.name() == "fetch") continue;
if (item.persistable()) {
ws.push_back(item.name());
}
}
CHECK(exec_scope_);
SerializeTensors(os, *exec_scope_, ws);
}
} // namespace lite
} // namespace paddle
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -115,6 +116,9 @@ struct Instruction { ...@@ -115,6 +116,9 @@ struct Instruction {
return os; return os;
} }
const OpLite* op() const { return op_.get(); }
const KernelBase* kernel() const { return kernel_.get(); }
private: private:
std::shared_ptr<OpLite> op_; std::shared_ptr<OpLite> op_;
std::unique_ptr<KernelBase> kernel_; std::unique_ptr<KernelBase> kernel_;
...@@ -128,8 +132,8 @@ class RuntimeProgram { ...@@ -128,8 +132,8 @@ class RuntimeProgram {
public: public:
explicit RuntimeProgram(std::vector<Instruction>&& insts) explicit RuntimeProgram(std::vector<Instruction>&& insts)
: instructions_(std::move(insts)) { : instructions_(std::move(insts)) {
if (insts.empty()) { if (instructions_.empty()) {
LOG(ERROR) << "no instructions"; LOG(FATAL) << "no instructions";
} }
} }
...@@ -140,11 +144,20 @@ class RuntimeProgram { ...@@ -140,11 +144,20 @@ class RuntimeProgram {
} }
} }
// Serialize the graph and save to the disk.
void PersistModel(const std::string& path,
const framework::proto::ProgramDesc& desc);
void set_exec_scope(lite::Scope* x) { exec_scope_ = x; } void set_exec_scope(lite::Scope* x) { exec_scope_ = x; }
lite::Scope* exec_scope() { return exec_scope_; } lite::Scope* exec_scope() { return exec_scope_; }
size_t num_instructions() const { return instructions_.size(); } size_t num_instructions() const { return instructions_.size(); }
protected:
std::string SerializeModelTopology(const framework::proto::ProgramDesc& desc);
void SerializeParams(std::ostream& os,
const framework::proto::ProgramDesc& desc);
private: private:
RuntimeProgram(const RuntimeProgram&) = delete; RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_; std::vector<Instruction> instructions_;
......
...@@ -32,21 +32,19 @@ Program FakeProgram() { ...@@ -32,21 +32,19 @@ Program FakeProgram() {
auto b1v = program.scope->Var(b1)->GetMutable<Tensor>(); auto b1v = program.scope->Var(b1)->GetMutable<Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<Tensor>(); auto out1v = program.scope->Var(out1)->GetMutable<Tensor>();
framework::OpDesc desc; lite::OpDesc desc;
desc.SetInput("Input", {x}); desc.SetInput("Input", {x});
desc.SetInput("W", {w1}); desc.SetInput("W", {w1});
desc.SetInput("Bias", {b1}); desc.SetInput("Bias", {b1});
desc.SetOutput("Out", {out1}); desc.SetOutput("Out", {out1});
desc.SetType("fc"); desc.SetType("fc");
desc.SetAttr("in_num_col_dims", 1); desc.SetAttr<int>("in_num_col_dims", 1);
desc.Flush();
// add to input // add to input
program.tmp_vars.push_back(w1); program.tmp_vars.push_back(w1);
program.tmp_vars.push_back(b1); program.tmp_vars.push_back(b1);
auto fc_op = LiteOpRegistry::Global().Create("fc"); auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc_op->Attach(desc, program.scope.get()); fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op)); program.ops.emplace_back(std::move(fc_op));
......
...@@ -164,6 +164,8 @@ class TargetWrapper { ...@@ -164,6 +164,8 @@ class TargetWrapper {
}; };
// This interface should be specified by each kind of target. // This interface should be specified by each kind of target.
using TargetWrapperHost = TargetWrapper<TARGET(kHost)>;
using TargetWrapperX86 = TargetWrapperHost;
template <> template <>
class TargetWrapper<TARGET(kHost)> { class TargetWrapper<TARGET(kHost)> {
public: public:
...@@ -196,6 +198,8 @@ class TargetWrapper<TARGET(kHost)> { ...@@ -196,6 +198,8 @@ class TargetWrapper<TARGET(kHost)> {
}; };
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
using TargetWrapperCuda =
TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
// This interface should be specified by each kind of target. // This interface should be specified by each kind of target.
template <> template <>
class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> { class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
......
...@@ -58,7 +58,7 @@ class Tensor { ...@@ -58,7 +58,7 @@ class Tensor {
const DDim& dims() const { return dims_; } const DDim& dims() const { return dims_; }
const LoD& lod() { return lod_; } const LoD& lod() const { return lod_; }
LoD* mutable_lod() { return &lod_; } LoD* mutable_lod() { return &lod_; }
template <typename T> template <typename T>
......
cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite)
cc_library(runtime_lite SRCS runtime.cc) cc_library(runtime_lite SRCS runtime.cc)
cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite) cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...@@ -7,5 +6,8 @@ else() ...@@ -7,5 +6,8 @@ else()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
compatible_pb_lite)
add_subdirectory(pb) add_subdirectory(pb)
...@@ -37,6 +37,7 @@ int SizeOfType(framework::proto::VarType::Type type) { ...@@ -37,6 +37,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
default: default:
LOG(FATAL) << "unknown data type"; LOG(FATAL) << "unknown data type";
} }
return -1;
} }
void TensorFromStream(std::istream &is, lite::Tensor *tensor) { void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
...@@ -162,5 +163,73 @@ void LoadModel(const std::string &model_dir, Scope *scope, ...@@ -162,5 +163,73 @@ void LoadModel(const std::string &model_dir, Scope *scope,
} }
} }
void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{
int size = tensor.lod().size();
// the 2st field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (auto &each : tensor.lod()) {
size = each.size() * sizeof(each.front());
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
os.write(reinterpret_cast<const char *>(each.data()),
static_cast<std::streamsize>(size));
}
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::proto::VarType::TensorDesc desc;
desc.set_data_type(framework::proto::VarType_Type_LOD_TENSOR);
auto dims = tensor.dims();
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
CHECK_LT(size, std::numeric_limits<std::streamsize>::max())
<< "Index overflow when writing tensor";
#ifdef LITE_WITH_CUDA
if (tensor.target() == TARGET(kCUDA)) {
std::unique_ptr<char> tmp_buffer(new char[size]);
TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<char>(),
tensor.memory_size(), IoDirection::DtoH);
os.write(static_cast<const char *>(tmp_buffer.get()),
static_cast<std::streamsize>(size));
} else
#endif // LITE_WITH_CUDA
{
os.write(static_cast<const char *>(tensor.data<void>()),
static_cast<std::streamsize>(size));
}
}
}
void SerializeTensors(std::ostream &os, const lite::Scope &scope,
const std::vector<std::string> &vars) {
// Store all the persistable vars.
for (const auto &_var : vars) {
auto *var = scope.FindVar(_var);
const auto &tensor = var->Get<lite::Tensor>();
TensorToStream(os, tensor);
}
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -40,5 +40,12 @@ void LoadParam(const std::string& path, Variable* out); ...@@ -40,5 +40,12 @@ void LoadParam(const std::string& path, Variable* out);
void LoadModel(const std::string& model_dir, Scope* scope, void LoadModel(const std::string& model_dir, Scope* scope,
framework::proto::ProgramDesc* prog); framework::proto::ProgramDesc* prog);
// Serialize tensors to ostream.
void SerializeTensors(std::ostream& os, const lite::Scope& scope,
const std::vector<std::string>& vars);
// LoDTensor to ostream
void TensorToStream(std::ostream& os, const lite::Tensor& tensor);
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -13,3 +13,31 @@ ...@@ -13,3 +13,31 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/op_desc.h" #include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace paddle {
namespace lite {
namespace pb {
template <>
void OpDesc::SetAttr<std::string>(const std::string &name,
const std::string &v) {
auto &xs = *desc_.mutable_attrs();
auto it = std::find_if(
xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; });
if (it == xs.end()) {
auto *attr = xs.Add();
attr->set_name(name);
it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) {
return x.name() == name;
});
}
it->set_type(framework::proto::STRING);
it->set_s(v.c_str());
}
} // namespace pb
} // namespace lite
} // namespace paddle
...@@ -120,28 +120,24 @@ class OpDesc { ...@@ -120,28 +120,24 @@ class OpDesc {
if (it == xs.end()) { if (it == xs.end()) {
auto *attr = xs.Add(); auto *attr = xs.Add();
attr->set_name(name); attr->set_name(name);
it = std::find(xs.begin(), xs.end(), name); it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) {
return x.name() == name;
});
} }
switch (typeid(T).hash_code()) { size_t hash = typeid(T).hash_code();
case typeid(int).hash_code(): if (hash == typeid(int).hash_code()) {
it->set_type(framework::proto::INT); it->set_type(framework::proto::INT);
it->set_i(v); it->set_i(v);
break; } else if (hash == typeid(float).hash_code()) {
case typeid(float).hash_code(): it->set_type(framework::proto::FLOAT);
it->set_type(framework::proto::FLOAT); it->set_f(v);
it->set_f(v); } else if (hash == typeid(bool).hash_code()) {
break; it->set_type(framework::proto::BOOLEAN);
case typeid(std::string).hash_code(): it->set_b(v);
it->set_type(framework::proto::STRING); } else {
it->set_s(v.c_str()); LOG(FATAL) << "unsupport attr type";
break;
case typeid(std::string).hash_code():
it->set_type(framework::proto::BOOLEAN);
it->set_b(v);
break;
default:
LOG(FATAL) << "unsupport attr type";
} }
} }
...@@ -229,6 +225,10 @@ class OpDesc { ...@@ -229,6 +225,10 @@ class OpDesc {
framework::proto::OpDesc desc_; framework::proto::OpDesc desc_;
}; };
template <>
void OpDesc::SetAttr<std::string>(const std::string &name,
const std::string &v);
} // namespace pb } // namespace pb
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -17,5 +17,6 @@ ...@@ -17,5 +17,6 @@
#include "paddle/fluid/lite/utils/check.h" #include "paddle/fluid/lite/utils/check.h"
#include "paddle/fluid/lite/utils/factory.h" #include "paddle/fluid/lite/utils/factory.h"
#include "paddle/fluid/lite/utils/hash.h" #include "paddle/fluid/lite/utils/hash.h"
#include "paddle/fluid/lite/utils/io.h"
#include "paddle/fluid/lite/utils/macros.h" #include "paddle/fluid/lite/utils/macros.h"
#include "paddle/fluid/lite/utils/varient.h" #include "paddle/fluid/lite/utils/varient.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <string>
namespace paddle {
namespace lite {
static bool IsFileExists(const std::string& path) {
std::ifstream file(path);
bool res = file.is_open();
if (res) {
file.close();
}
return res;
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册