From 886897ccf742f3c95714703b5ed925d35a56e46e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 16 May 2018 16:05:33 +0800 Subject: [PATCH] load implement --- paddle/fluid/operators/checkpoint_load_op.cc | 48 ++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index ec451c9f3f0..ba8b5dbb51c 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -47,13 +47,54 @@ class CheckpointLoadOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto dir = Attr("dir"); - bool is_present = FileExists(dir); + std::string dir = Attr("dir"); + + VLOG(3) << "Load checkpoint from dir: " << dir; + + std::string success; + success.append(dir); + success.append("/"); + success.append(SUCCESS); + + bool is_present = FileExists(success); if (!is_present) { + VLOG(3) << "can not find _SUCCESS from path: " << success; return; } - // UPDATE LATER ... + auto inp_var_names = Output("Out"); + PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, + "The number of input variables should be greater than 0"); + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // todo (tangwei) made it async + for (size_t i = 0; i < inp_var_names.size(); i++) { + auto *var = scope.FindVar(inp_var_names[i]); + + PADDLE_ENFORCE(var != nullptr, + "Cannot find variable %s for save_combine_op", + inp_var_names[i]); + PADDLE_ENFORCE(var->IsType(), + "SaveCombineOp only supports LoDTensor, %s has wrong type", + inp_var_names[i]); + + std::string var_file; + var_file.append(dir); + var_file.append("/"); + var_file.append(inp_var_names[i]); + VLOG(3) << "ready to load var: " << inp_var_names[i]; + + auto &tensor = var->Get(); + + std::ifstream fin(var_file); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", + var_file); + DeserializeFromStream(fin, tensor, *dev_ctx); + fin.close(); + VLOG(3) << " load var: " << inp_var_names[i] << " finished"; + } } }; @@ -61,6 +102,7 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "(Tensor) The tensor need to be loaded"); AddComment(R"DOC( CheckpointLoad operator -- GitLab