diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index 04587cce2fc742fd64f8b9236f9175a7e5f11e89..32a4a56a209f215798ebcb039abbbb12d25d266f 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -1,7 +1,7 @@
 if(LITE_WITH_CUDA)
-    cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
+    cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda optimizer_lite)
     nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
 else()
-    cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host )
+    cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
     cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
 endif()
diff --git a/paddle/fluid/lite/api/cxx_api.cc b/paddle/fluid/lite/api/cxx_api.cc
index 35a0373a25a9d71be3d6628e29ace02ce1789d02..3f05e0671d86dc0e49b8e97c643769d33727e761 100644
--- a/paddle/fluid/lite/api/cxx_api.cc
+++ b/paddle/fluid/lite/api/cxx_api.cc
@@ -13,7 +13,15 @@
 // limitations under the License.

 #include "paddle/fluid/lite/api/cxx_api.h"
+#include "paddle/fluid/platform/port.h"

 namespace paddle {
-namespace lite {}  // namespace lite
+namespace lite {
+
+void Predictor::SaveModel(const std::string &dir) {
+  MkDirRecursively(dir.c_str());
+  program_->PersistModel(dir, program_desc_);
+}
+
+}  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h
index ed2654c02a5e28c5571a8db2916ba1dbe078791a..0d61325eac6fdd71fe55a1b8d4b3611d921a4947 100644
--- a/paddle/fluid/lite/api/cxx_api.h
+++ b/paddle/fluid/lite/api/cxx_api.h
@@ -31,19 +31,19 @@ class Predictor {

   void Build(const std::string& model_path, const Place& prefer_place,
              const std::vector<Place>& valid_places) {
-    framework::proto::ProgramDesc prog;
-    LoadModel(model_path, scope_.get(), &prog);
+    LoadModel(model_path, scope_.get(), &program_desc_);

-    Program program(prog, scope_, valid_places);
+    Program program(program_desc_, scope_, valid_places);

-    Optimizer optimizer;
-    optimizer.KernelPickPreferPlace(prefer_place);
+    optimizer_.KernelPickPreferPlace(prefer_place);
     core::KernelPickFactor factor;
     factor.ConsiderTarget();
-    optimizer.Run(std::move(program), valid_places, factor);
-    program_ = optimizer.GenRuntimeProgram();
+    optimizer_.Run(std::move(program), valid_places, factor);
+    program_ = optimizer_.GenRuntimeProgram();
   }

+  void SaveModel(const std::string& dir);
+
   // Get offset-th col of feed.
   Tensor* GetInput(size_t offset) {
     auto* _feed_list = program_->exec_scope()->FindVar("feed");
@@ -65,7 +65,13 @@ class Predictor {

   void Run() { program_->Run(); }

+  const framework::proto::ProgramDesc& program_desc() const {
+    return program_desc_;
+  }
+
  private:
+  Optimizer optimizer_;
+  framework::proto::ProgramDesc program_desc_;
   std::shared_ptr<Scope> scope_;
   std::unique_ptr<RuntimeProgram> program_;
 };
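For orientation, a minimal sketch of how the new Build/SaveModel pair is driven end to end. The paths are placeholders, and the Place spelling is copied from the tests below; this is a usage sketch, not part of the patch:

#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"

// Hedged usage sketch: build from a source model directory, then persist
// the optimized program plus its parameters to a new directory.
void BuildAndSave() {
  paddle::lite::Predictor predictor;
  std::vector<paddle::lite::Place> valid_places(
      {paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)}});
  predictor.Build("/path/to/model_dir",  // placeholder input directory
                  paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)},
                  valid_places);
  predictor.SaveModel("/path/to/optimized_model");  // dir created via MkDirRecursively
}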
diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc
index 7397e837ab34d03610b69503022d2a462c0954ec..79e0b6862a411127d435069103a066f50d59bc38 100644
--- a/paddle/fluid/lite/api/cxx_api_test.cc
+++ b/paddle/fluid/lite/api/cxx_api_test.cc
@@ -36,7 +36,7 @@ TEST(CXXApi, test) {
   });
 #endif

-  predictor.Build("/home/chunwei/project2/models/model2",
+  predictor.Build("/home/chunwei/project/models/model2",
                   Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);

   auto* input_tensor = predictor.GetInput(0);
@@ -59,6 +59,15 @@ TEST(CXXApi, test) {
   LOG(INFO) << "out " << *out;
 }

+TEST(CXXApi, save_model) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
+  predictor.Build("/home/chunwei/project/models/model2",
+                  Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
+
+  predictor.SaveModel("./optimized_model");
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index 7b872d85593b21befcbb12fce3eab1c6a504f68b..b6a3602cc90338e53138a4cf6d5d12c5beca5fee 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -12,13 +12,13 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
 cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
-cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager)
 cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph scope_lite op_registry_lite proto_desc op_lite ops_lite host_kernels )
 cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
+cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)

 cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
 cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index 0dee3968f6e7ec8fb9f93636c4541c2c69c541a5..70cc623bcfeff1a7b5321fcbc53979b650b4d732 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -96,6 +96,14 @@ class KernelBase {
   // Generate the key of the parameter type.
   std::string GenParamTypeKey() const;

+  std::string SerializeKernelType() const {
+    std::stringstream ss;
+    ss << op_type() << "/";
+    ss << alias_ << "/";
+    ss << place();
+    return ss.str();
+  }
+
   virtual ~KernelBase() = default;
   void Torch() {}
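SerializeKernelType is what later lets the saved __model__ record which kernel the optimizer picked for each op: it joins the op type, the kernel alias, and the place into one '/'-separated key (for an fc host kernel, something like "fc/def/kHost/..."; the exact place() formatting is an assumption here). A small hedged sketch of splitting such a key back apart:

#include <sstream>
#include <string>
#include <vector>

// Hypothetical inverse of SerializeKernelType: split the '/'-separated key.
// fields[0] is the op type, fields[1] the alias, the remainder the place.
std::vector<std::string> SplitKernelType(const std::string &key) {
  std::vector<std::string> fields;
  std::stringstream ss(key);
  std::string field;
  while (std::getline(ss, field, '/')) fields.push_back(field);
  return fields;
}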
diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc
index 62ca701b073112c1fc8220e201db1d80fc89eafe..ec671d7bbfa94031133ffca486a92563b99d53c9 100644
--- a/paddle/fluid/lite/core/mir/generate_program_pass.cc
+++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc
@@ -22,7 +22,7 @@ namespace mir {

 void GenerateProgramPass::Apply(std::unique_ptr<SSAGraph>& graph) {
   LOG(INFO) << "final program \n" << Visualize(graph.get());
-  for (auto& item : graph->InstructTopologicalOrder()) {
+  for (auto& item : graph->StmtTopologicalOrder()) {
     if (item->IsStmt()) {
       auto& stmt = item->AsStmt();
       LOG(INFO) << stmt;
diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.h b/paddle/fluid/lite/core/mir/generate_program_pass.h
index 0a16afb1bed4bafb1627d3694a75357a00db2a1c..222d7bc307083bd9637677f4d5e94afed04170d1 100644
--- a/paddle/fluid/lite/core/mir/generate_program_pass.h
+++ b/paddle/fluid/lite/core/mir/generate_program_pass.h
@@ -31,6 +31,7 @@ class GenerateProgramPass : public ProgramPass {
   void Apply(std::unique_ptr<SSAGraph>& graph) override;

   std::unique_ptr<RuntimeProgram> GenProgram() {
+    LOG(INFO) << "insts.size " << insts_.size();
     std::unique_ptr<RuntimeProgram> program(
         new RuntimeProgram(std::move(insts_)));
     return program;
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc
index e3a1a0ed4da45200165d88e42861656a59a36e8f..e807be78e6e47e9eef4677155a829cb0b6f2cecf 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph.cc
@@ -71,7 +71,7 @@ void SSAGraph::SortHelper(
   ret->push_back(node);
 }

-std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
+std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
   CheckBidirectionalConnection();

   std::stack<mir::Node *> stack;
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h
index 6f860773b451e9b9b9acebb622ee576559de9b6e..bcdc963aff69bd2e3e68adf959485f163650f824 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.h
+++ b/paddle/fluid/lite/core/mir/ssa_graph.h
@@ -39,7 +39,7 @@ class SSAGraph : GraphBase {

   mir::Node *Argument(const std::string &name);

-  std::vector<mir::Node *> InstructTopologicalOrder();
+  std::vector<mir::Node *> StmtTopologicalOrder();

   // The inputs of the graph.
   std::vector<mir::Node *> inputs();
diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
index ae5621501daede0c4a2a812ff4ecbdbfb5dbc371..eb32111fdf69b46be8cbbb603495438327632414 100644
--- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
@@ -58,7 +58,7 @@ class VariablePlaceInferencePass : public DebugPass {
   void InferenceArgumentPlace(SSAGraph* graph) {
     VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
-    for (auto& x : graph->InstructTopologicalOrder()) {
+    for (auto& x : graph->StmtTopologicalOrder()) {
       auto& inst = x->AsStmt();
       // The IoCopyOp is a tool operator, it won't support the type inference.
       if (inst.op_type == "io_copy") continue;
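The renamed StmtTopologicalOrder still emits nodes so that every statement appears after the values it consumes; the underlying technique is the DFS post-order driven by SortHelper. A self-contained sketch on a plain integer-indexed graph (the mir::Node wiring is elided; this assumes an acyclic graph, as the real pass does):

#include <vector>

// DFS in the style of SSAGraph::SortHelper: emit a node only after all of
// the nodes it depends on (its inlinks) have been emitted. No cycle check;
// the input is assumed to be a DAG.
void SortHelper(int node, const std::vector<std::vector<int>> &inlinks,
                std::vector<bool> *visited, std::vector<int> *ret) {
  (*visited)[node] = true;
  for (int in : inlinks[node]) {
    if (!(*visited)[in]) SortHelper(in, inlinks, visited, ret);
  }
  ret->push_back(node);
}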
diff --git a/paddle/fluid/lite/core/optimizer.cc b/paddle/fluid/lite/core/optimizer.cc
index bb9fb5fe06760f2f7078b893157f6f1f65d058a8..1502d15e2bfa70b94a87686c72108e26175730b0 100644
--- a/paddle/fluid/lite/core/optimizer.cc
+++ b/paddle/fluid/lite/core/optimizer.cc
@@ -13,8 +13,11 @@
 // limitations under the License.

 #include "paddle/fluid/lite/core/optimizer.h"
+#include
 #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
 #include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
+#include "paddle/fluid/lite/model_parser/model_parser.h"
+#include "paddle/fluid/lite/utils/all.h"

 namespace paddle {
 namespace lite {
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index f7a4a7989c48c4bf1b32303a3fc9a499d1181db1..34272f570d04d981644706e2a03c35fa3621d705 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -22,6 +22,7 @@
 #include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
 #include "paddle/fluid/lite/core/program.h"
 #include "paddle/fluid/lite/core/types.h"
+#include "paddle/fluid/lite/model_parser/model_parser.h"

 namespace paddle {
 namespace lite {
@@ -35,6 +36,7 @@ class Optimizer {
   void Run(Program&& program, const std::vector<Place>& valid_places,
            core::KernelPickFactor kernel_pick_factor,
            const std::vector<std::string>& passes = {}) {
+    program_ = &program;
     valid_places_ = valid_places;
     CHECK(!valid_places.empty()) << "At least one valid_place should be set";
     CHECK(!graph_) << "duplicate optimize found";
@@ -100,6 +102,11 @@ class Optimizer {
     return *graph_;
   }

+  mir::SSAGraph* mutable_ssa_graph() {
+    CHECK(graph_);
+    return graph_.get();
+  }
+
  protected:
   void SpecifyKernelPickTactic(core::KernelPickFactor factor);

@@ -117,6 +124,7 @@ class Optimizer {
   std::unique_ptr<mir::SSAGraph> graph_;
   std::vector<Place> valid_places_;
   lite::Scope* exec_scope_{};
+  Program* program_{};
 };

 }  // namespace lite
diff --git a/paddle/fluid/lite/core/optimizer_test.cc b/paddle/fluid/lite/core/optimizer_test.cc
index 85ae5981758dd6daf8eb13fcacdc524370ef6494..19a73a62cff81c7fbd478bdc3618b7e6d9be6641 100644
--- a/paddle/fluid/lite/core/optimizer_test.cc
+++ b/paddle/fluid/lite/core/optimizer_test.cc
@@ -28,15 +28,10 @@ TEST(Optimizer, test) {
   auto program = ProgramFaker();
   std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}});

-  auto* pick_pass =
-      mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
-          "static_kernel_pick_pass");
-  ASSERT_TRUE(pick_pass != nullptr);
-  pick_pass->mutable_kernel_pick_factors()
-      ->ConsiderTarget()
-      .ConsiderPrecision();
+  core::KernelPickFactor factor;
+  factor.ConsiderTarget();

-  optimizer.Run(std::move(program), places);
+  optimizer.Run(std::move(program), places, factor);
   auto runtime_program = optimizer.GenRuntimeProgram();
   LOG(INFO) << "num statements " << runtime_program->num_instructions();
 }
@@ -45,4 +40,4 @@
 }  // namespace paddle

 USE_LITE_OP(fc);
-USE_LITE_KERNEL(fc, kHost, kFloat, def);
+USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/core/program.cc b/paddle/fluid/lite/core/program.cc
index 7a528740e388a0b1bb4363eca99b3181c623df02..9ebefe33e5e0b9e7a0fdcd1d285e9d5f7ceb6c76 100644
--- a/paddle/fluid/lite/core/program.cc
+++ b/paddle/fluid/lite/core/program.cc
@@ -13,3 +13,59 @@
 // limitations under the License.

 #include "paddle/fluid/lite/core/program.h"
+#include "paddle/fluid/lite/core/optimizer.h"
+
+namespace paddle {
+namespace lite {
+
+void RuntimeProgram::PersistModel(const std::string &path,
+                                  const framework::proto::ProgramDesc &desc) {
+  // Persist model.
+  const std::string model_path = path + "/__model__";
+  std::ofstream model_ostream(model_path, std::ios_base::binary);
+  CHECK(model_ostream.is_open());
+  const std::string pb_str = SerializeModelTopology(desc);
+  model_ostream.write(pb_str.c_str(), pb_str.size());
+
+  // Persist params.
+  const std::string params_path = path + "/params";
+  CHECK(!IsFileExists(params_path)) << "file " << params_path
+                                    << " exists, can't overwrite";
+  std::ofstream params_ostream(params_path, std::ios_base::binary);
+  CHECK(params_ostream.is_open());
+  framework::proto::ProgramDesc latest_program;
+  latest_program.ParseFromString(pb_str);
+  SerializeParams(params_ostream, latest_program);
+}
+
+std::string RuntimeProgram::SerializeModelTopology(
+    const framework::proto::ProgramDesc &desc) {
+  const std::string kKernelTypeAttr = "__@kernel_type_attr@__";
+  auto program_dummy = desc;
+  program_dummy.mutable_blocks(0)->clear_ops();
+  for (auto &node : instructions_) {
+    auto desc_dummy = node.op()->op_info()->desc();
+    OpDesc desc(desc_dummy);
+    desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializeKernelType());
+    // append new opdesc
+    *program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto();
+  }
+  return program_dummy.SerializeAsString();
+}
+
+void RuntimeProgram::SerializeParams(
+    std::ostream &os, const framework::proto::ProgramDesc &desc) {
+  std::vector<std::string> ws;
+  for (auto &item : desc.blocks(0).vars()) {
+    if (item.name() == "feed" || item.name() == "fetch") continue;
+    if (item.persistable()) {
+      ws.push_back(item.name());
+    }
+  }
+
+  CHECK(exec_scope_);
+  SerializeTensors(os, *exec_scope_, ws);
+}
+
+}  // namespace lite
+}  // namespace paddle
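PersistModel therefore leaves two files under the target directory: __model__ (the ProgramDesc protobuf, each op annotated with the picked kernel under the "__@kernel_type_attr@__" attribute) and params (the concatenated persistable tensors). A hedged sketch of loading the topology back, using only standard protobuf calls; LoadTopology is a hypothetical helper, not part of this patch:

#include <fstream>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"

// Hypothetical reader for the "__model__" file written by PersistModel.
paddle::framework::proto::ProgramDesc LoadTopology(const std::string &dir) {
  std::ifstream is(dir + "/__model__", std::ios_base::binary);
  std::stringstream buffer;
  buffer << is.rdbuf();
  paddle::framework::proto::ProgramDesc desc;
  desc.ParseFromString(buffer.str());
  // Each op now carries "__@kernel_type_attr@__" describing its kernel.
  return desc;
}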
#include "paddle/fluid/lite/core/program.h" +#include "paddle/fluid/lite/core/optimizer.h" + +namespace paddle { +namespace lite { + +void RuntimeProgram::PersistModel(const std::string &path, + const framework::proto::ProgramDesc &desc) { + // Persist model. + const std::string model_path = path + "/__model__"; + std::ofstream model_ostream(model_path, std::ios_base::binary); + CHECK(model_ostream.is_open()); + const std::string pb_str = SerializeModelTopology(desc); + model_ostream.write(pb_str.c_str(), pb_str.size()); + + // Persist params. + const std::string params_path = path + "/params"; + CHECK(!IsFileExists(params_path)) << "file " << params_path + << " exists, can't overwrite"; + std::ofstream params_ostream(params_path, std::ios_base::binary); + CHECK(params_ostream.is_open()); + framework::proto::ProgramDesc latest_program; + latest_program.ParseFromString(pb_str); + SerializeParams(params_ostream, latest_program); +} + +std::string RuntimeProgram::SerializeModelTopology( + const framework::proto::ProgramDesc &desc) { + const std::string kKernelTypeAttr = "__@kernel_type_attr@__"; + auto program_dummy = desc; + program_dummy.mutable_blocks(0)->clear_ops(); + for (auto &node : instructions_) { + auto desc_dummy = node.op()->op_info()->desc(); + OpDesc desc(desc_dummy); + desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializeKernelType()); + // append new opdesc + *program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto(); + } + return program_dummy.SerializeAsString(); +} + +void RuntimeProgram::SerializeParams( + std::ostream &os, const framework::proto::ProgramDesc &desc) { + std::vector ws; + for (auto &item : desc.blocks(0).vars()) { + if (item.name() == "feed" || item.name() == "fetch") continue; + if (item.persistable()) { + ws.push_back(item.name()); + } + } + + CHECK(exec_scope_); + SerializeTensors(os, *exec_scope_, ws); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 2d6560a59be0232d9350a4fbd6b5243d9304ab4b..49dfbca68ccd08ac461542cabb4b3b19aa076c4e 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -115,6 +116,9 @@ struct Instruction { return os; } + const OpLite* op() const { return op_.get(); } + const KernelBase* kernel() const { return kernel_.get(); } + private: std::shared_ptr op_; std::unique_ptr kernel_; @@ -128,8 +132,8 @@ class RuntimeProgram { public: explicit RuntimeProgram(std::vector&& insts) : instructions_(std::move(insts)) { - if (insts.empty()) { - LOG(ERROR) << "no instructions"; + if (instructions_.empty()) { + LOG(FATAL) << "no instructions"; } } @@ -140,11 +144,20 @@ class RuntimeProgram { } } + // Serialize the graph and save to the disk. 
diff --git a/paddle/fluid/lite/core/program_fake_utils.h b/paddle/fluid/lite/core/program_fake_utils.h
index b27c90c3cfe0a4fd96e2732b04bf0805c018c495..e1dafc8ac522e296a868f881aeef995415c8cbd7 100644
--- a/paddle/fluid/lite/core/program_fake_utils.h
+++ b/paddle/fluid/lite/core/program_fake_utils.h
@@ -32,21 +32,19 @@ Program FakeProgram() {
   auto b1v = program.scope->Var(b1)->GetMutable<Tensor>();
   auto out1v = program.scope->Var(out1)->GetMutable<Tensor>();

-  framework::OpDesc desc;
+  lite::OpDesc desc;
   desc.SetInput("Input", {x});
   desc.SetInput("W", {w1});
   desc.SetInput("Bias", {b1});
   desc.SetOutput("Out", {out1});
   desc.SetType("fc");
-  desc.SetAttr("in_num_col_dims", 1);
-  desc.Flush();
+  desc.SetAttr("in_num_col_dims", 1);

   // add to input
   program.tmp_vars.push_back(w1);
   program.tmp_vars.push_back(b1);

   auto fc_op = LiteOpRegistry::Global().Create("fc");
-  fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
   fc_op->Attach(desc, program.scope.get());
   program.ops.emplace_back(std::move(fc_op));
diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h
index fd754a09ec0c68400f33f7c892367557c61705d6..f14579bf6d724cf103bf698706acb47a039764ac 100644
--- a/paddle/fluid/lite/core/target_wrapper.h
+++ b/paddle/fluid/lite/core/target_wrapper.h
@@ -164,6 +164,8 @@ class TargetWrapper {
 };

 // This interface should be specified by each kind of target.
+using TargetWrapperHost = TargetWrapper<TARGET(kHost)>;
+using TargetWrapperX86 = TargetWrapperHost;
 template <>
 class TargetWrapper<TARGET(kHost)> {
  public:
@@ -196,6 +198,8 @@
 };

 #ifdef LITE_WITH_CUDA
+using TargetWrapperCuda =
+    TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
 // This interface should be specified by each kind of target.
 template <>
 class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
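The new aliases make call sites like the serializer below shorter. A hedged example of a plain host copy through the alias; it assumes the MemcpySync(dst, src, size, IoDirection) signature used later in TensorToStream, and that IoDirection::HtoH is a defined enumerator:

#include <vector>
#include "paddle/fluid/lite/core/target_wrapper.h"

// Sketch: copy a host buffer through the TargetWrapperHost alias.
void CopyOnHost(const std::vector<float> &src, std::vector<float> *dst) {
  dst->resize(src.size());
  paddle::lite::TargetWrapperHost::MemcpySync(
      dst->data(), src.data(), src.size() * sizeof(float),
      paddle::lite::IoDirection::HtoH);  // assumed direction enum value
}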
diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h
index 2e8dc47ffc4ab0b3fb6644db015cd9926897b229..246bc5b214c01fcc18d659bfad37beab273ffdd3 100644
--- a/paddle/fluid/lite/core/tensor.h
+++ b/paddle/fluid/lite/core/tensor.h
@@ -58,7 +58,7 @@ class Tensor {

   const DDim& dims() const { return dims_; }

-  const LoD& lod() { return lod_; }
+  const LoD& lod() const { return lod_; }
   LoD* mutable_lod() { return &lod_; }

   template <typename T>
diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt
index 18d4f15178936145a57a34e0ef5e24d923438b04..a732293a49ed6c69957a019a427dc586835cebe7 100644
--- a/paddle/fluid/lite/model_parser/CMakeLists.txt
+++ b/paddle/fluid/lite/model_parser/CMakeLists.txt
@@ -1,4 +1,3 @@
-cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite)
 cc_library(runtime_lite SRCS runtime.cc)
 cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
 if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
@@ -7,5 +6,8 @@ else()
     cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
 endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)

+cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
+compatible_pb_lite)
+
 add_subdirectory(pb)
diff --git a/paddle/fluid/lite/model_parser/model_parser.cc b/paddle/fluid/lite/model_parser/model_parser.cc
index feaff82d31354818559661d08b8371c603430e58..8ee31a2f8d57df9b68f6b6438225e3042be8f3d2 100644
--- a/paddle/fluid/lite/model_parser/model_parser.cc
+++ b/paddle/fluid/lite/model_parser/model_parser.cc
@@ -37,6 +37,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
     default:
       LOG(FATAL) << "unknown data type";
   }
+  return -1;
 }

 void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
@@ -162,5 +163,73 @@ void LoadModel(const std::string &model_dir, Scope *scope,
   }
 }

+void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
+  {  // the 1st field, uint32_t version
+    constexpr uint32_t version = 0;
+    os.write(reinterpret_cast<const char *>(&version), sizeof(version));
+  }
+
+  {
+    int size = tensor.lod().size();
+    // the 2nd field, LoD information
+    // uint64_t lod_level
+    // uint64_t lod_level_1 size in byte.
+    // int* lod_level_1 data
+    // ...
+    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
+
+    for (auto &each : tensor.lod()) {
+      size = each.size() * sizeof(each.front());
+      os.write(reinterpret_cast<const char *>(&size), sizeof(size));
+      os.write(reinterpret_cast<const char *>(each.data()),
+               static_cast<std::streamsize>(size));
+    }
+  }
+
+  {  // the 3rd field, tensor description
+    // int32_t size
+    // void* protobuf message
+    framework::proto::VarType::TensorDesc desc;
+    desc.set_data_type(framework::proto::VarType_Type_LOD_TENSOR);
+    auto dims = tensor.dims();
+    auto *pb_dims = desc.mutable_dims();
+    pb_dims->Resize(static_cast<int>(dims.size()), 0);
+    std::copy(dims.begin(), dims.end(), pb_dims->begin());
+    int32_t size = desc.ByteSize();
+    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
+    auto out = desc.SerializeAsString();
+    os.write(out.data(), size);
+  }
+  {  // the 4th field, tensor data
+    uint64_t size = tensor.memory_size();
+    CHECK_LT(size, std::numeric_limits<std::streamsize>::max())
+        << "Index overflow when writing tensor";
+
+#ifdef LITE_WITH_CUDA
+    if (tensor.target() == TARGET(kCUDA)) {
+      std::unique_ptr<char[]> tmp_buffer(new char[size]);
+      TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data(),
+                                    tensor.memory_size(), IoDirection::DtoH);
+      os.write(static_cast<const char *>(tmp_buffer.get()),
+               static_cast<std::streamsize>(size));
+    } else
+#endif  // LITE_WITH_CUDA
+    {
+      os.write(static_cast<const char *>(tensor.data()),
+               static_cast<std::streamsize>(size));
+    }
+  }
+}
+
+void SerializeTensors(std::ostream &os, const lite::Scope &scope,
+                      const std::vector<std::string> &vars) {
+  // Store all the persistable vars.
+  for (const auto &_var : vars) {
+    auto *var = scope.FindVar(_var);
+    const auto &tensor = var->Get<lite::Tensor>();
+    TensorToStream(os, tensor);
+  }
+}
+
 }  // namespace lite
 }  // namespace paddle
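The params stream is thus a simple length-prefixed binary format per tensor: a uint32 version, the LoD level count plus each level's byte size and data, a size-prefixed TensorDesc protobuf, then the raw tensor bytes. A hedged sketch of walking one record's header (standard iostream calls only; field widths mirror the writer above):

#include <cstdint>
#include <istream>
#include <string>
#include <vector>

// Sketch: read the header of one tensor record written by TensorToStream.
void ReadTensorHeader(std::istream &is) {
  uint32_t version;
  is.read(reinterpret_cast<char *>(&version), sizeof(version));

  int lod_levels;  // the writer emits the LoD level count as an int
  is.read(reinterpret_cast<char *>(&lod_levels), sizeof(lod_levels));
  for (int i = 0; i < lod_levels; ++i) {
    int bytes;
    is.read(reinterpret_cast<char *>(&bytes), sizeof(bytes));
    std::vector<char> level(bytes);
    is.read(level.data(), bytes);  // raw LoD offsets for this level
  }

  int32_t desc_size;
  is.read(reinterpret_cast<char *>(&desc_size), sizeof(desc_size));
  std::string desc_buf(desc_size, '\0');
  is.read(&desc_buf[0], desc_size);  // serialized VarType::TensorDesc
  // ... tensor data of memory_size bytes follows.
}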
diff --git a/paddle/fluid/lite/model_parser/model_parser.h b/paddle/fluid/lite/model_parser/model_parser.h
index 41a5a9a93172df330064aeeb9f60fe54b8fee652..ef7a6752a4dc7546feb4fa6c2859c8bc03716243 100644
--- a/paddle/fluid/lite/model_parser/model_parser.h
+++ b/paddle/fluid/lite/model_parser/model_parser.h
@@ -40,5 +40,12 @@ void LoadParam(const std::string& path, Variable* out);
 void LoadModel(const std::string& model_dir, Scope* scope,
                framework::proto::ProgramDesc* prog);

+// Serialize tensors to ostream.
+void SerializeTensors(std::ostream& os, const lite::Scope& scope,
+                      const std::vector<std::string>& vars);
+
+// LoDTensor to ostream
+void TensorToStream(std::ostream& os, const lite::Tensor& tensor);
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.cc b/paddle/fluid/lite/model_parser/pb/op_desc.cc
index c546eccc926689e1ca2f7d338972c17f7356df38..fb269cd067180b9df30aba27f9dd61b61b58279d 100644
--- a/paddle/fluid/lite/model_parser/pb/op_desc.cc
+++ b/paddle/fluid/lite/model_parser/pb/op_desc.cc
@@ -13,3 +13,31 @@
 // limitations under the License.

 #include "paddle/fluid/lite/model_parser/pb/op_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace pb {
+
+template <>
+void OpDesc::SetAttr<std::string>(const std::string &name,
+                                  const std::string &v) {
+  auto &xs = *desc_.mutable_attrs();
+  auto it = std::find_if(
+      xs.begin(), xs.end(),
+      [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; });
+  if (it == xs.end()) {
+    auto *attr = xs.Add();
+    attr->set_name(name);
+    it = std::find_if(xs.begin(), xs.end(),
+                      [&](const framework::proto::OpDesc_Attr &x) {
+                        return x.name() == name;
+                      });
+  }
+
+  it->set_type(framework::proto::STRING);
+  it->set_s(v.c_str());
+}
+
+}  // namespace pb
+}  // namespace lite
+}  // namespace paddle
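With the std::string case split out as a full specialization (the generic template dispatches on typeid, which cannot cover string literals cleanly), attribute setting is uniform across types. A short usage sketch assuming the pb::OpDesc API above; the attribute names other than in_num_col_dims are hypothetical:

#include <string>
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"

// Sketch: one SetAttr entry point covers int, float, bool, and string;
// the string case routes to the new explicit specialization.
void FillAttrs(paddle::lite::pb::OpDesc *desc) {
  desc->SetAttr<int>("in_num_col_dims", 1);
  desc->SetAttr<float>("scale", 0.5f);                  // hypothetical attr
  desc->SetAttr<std::string>("note", "picked-kernel");  // hypothetical attr
}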
#include "paddle/fluid/lite/model_parser/pb/op_desc.h" + +namespace paddle { +namespace lite { +namespace pb { + +template <> +void OpDesc::SetAttr(const std::string &name, + const std::string &v) { + auto &xs = *desc_.mutable_attrs(); + auto it = std::find_if( + xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); + if (it == xs.end()) { + auto *attr = xs.Add(); + attr->set_name(name); + it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); + } + + it->set_type(framework::proto::STRING); + it->set_s(v.c_str()); +} + +} // namespace pb +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h index 7b1c362a1257a51b900796898cccd9098e54c157..054deec09e3b5f22afdf1bfee9a9bd6f6243305d 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.h +++ b/paddle/fluid/lite/model_parser/pb/op_desc.h @@ -120,28 +120,24 @@ class OpDesc { if (it == xs.end()) { auto *attr = xs.Add(); attr->set_name(name); - it = std::find(xs.begin(), xs.end(), name); + it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); } - switch (typeid(T).hash_code()) { - case typeid(int).hash_code(): - it->set_type(framework::proto::INT); - it->set_i(v); - break; - case typeid(float).hash_code(): - it->set_type(framework::proto::FLOAT); - it->set_f(v); - break; - case typeid(std::string).hash_code(): - it->set_type(framework::proto::STRING); - it->set_s(v.c_str()); - break; - case typeid(std::string).hash_code(): - it->set_type(framework::proto::BOOLEAN); - it->set_b(v); - break; - default: - LOG(FATAL) << "unsupport attr type"; + size_t hash = typeid(T).hash_code(); + if (hash == typeid(int).hash_code()) { + it->set_type(framework::proto::INT); + it->set_i(v); + } else if (hash == typeid(float).hash_code()) { + it->set_type(framework::proto::FLOAT); + it->set_f(v); + } else if (hash == typeid(bool).hash_code()) { + it->set_type(framework::proto::BOOLEAN); + it->set_b(v); + } else { + LOG(FATAL) << "unsupport attr type"; } } @@ -229,6 +225,10 @@ class OpDesc { framework::proto::OpDesc desc_; }; +template <> +void OpDesc::SetAttr(const std::string &name, + const std::string &v); + } // namespace pb } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/utils/all.h b/paddle/fluid/lite/utils/all.h index d9a4867717e13e56a2ccc55b891e0db07d791265..ff1242163d6fffe1b1bd3c7415515dc82b120465 100644 --- a/paddle/fluid/lite/utils/all.h +++ b/paddle/fluid/lite/utils/all.h @@ -17,5 +17,6 @@ #include "paddle/fluid/lite/utils/check.h" #include "paddle/fluid/lite/utils/factory.h" #include "paddle/fluid/lite/utils/hash.h" +#include "paddle/fluid/lite/utils/io.h" #include "paddle/fluid/lite/utils/macros.h" #include "paddle/fluid/lite/utils/varient.h" diff --git a/paddle/fluid/lite/utils/io.h b/paddle/fluid/lite/utils/io.h new file mode 100644 index 0000000000000000000000000000000000000000..c4ac283eaea5b52e6c3797c518ed457314431f4a --- /dev/null +++ b/paddle/fluid/lite/utils/io.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/paddle/fluid/lite/utils/io.h b/paddle/fluid/lite/utils/io.h
new file mode 100644
index 0000000000000000000000000000000000000000..c4ac283eaea5b52e6c3797c518ed457314431f4a
--- /dev/null
+++ b/paddle/fluid/lite/utils/io.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <fstream>
+#include <string>
+
+namespace paddle {
+namespace lite {
+
+static bool IsFileExists(const std::string& path) {
+  std::ifstream file(path);
+  bool res = file.is_open();
+  if (res) {
+    file.close();
+  }
+  return res;
+}
+
+}  // namespace lite
+}  // namespace paddle
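A closing note on the helper: PersistModel uses IsFileExists as an overwrite guard for params. A hedged sketch of the same pattern in isolation (WriteOnce is a hypothetical helper; the path is a placeholder):

#include <fstream>
#include <string>
#include "paddle/fluid/lite/utils/io.h"

// Sketch: refuse to clobber an existing artifact, mirroring PersistModel.
bool WriteOnce(const std::string &path, const std::string &payload) {
  if (paddle::lite::IsFileExists(path)) return false;  // don't overwrite
  std::ofstream os(path, std::ios_base::binary);
  if (!os.is_open()) return false;
  os.write(payload.data(), static_cast<std::streamsize>(payload.size()));
  return true;
}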