diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index ec451c9f3f0cd28a76442e136c3c91a8fa9f39c5..ba8b5dbb51c10a05d5d2af81cb36b2d0ed65247c 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -47,13 +47,54 @@ class CheckpointLoadOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto dir = Attr("dir"); - bool is_present = FileExists(dir); + std::string dir = Attr("dir"); + + VLOG(3) << "Load checkpoint from dir: " << dir; + + std::string success; + success.append(dir); + success.append("/"); + success.append(SUCCESS); + + bool is_present = FileExists(success); if (!is_present) { + VLOG(3) << "can not find _SUCCESS from path: " << success; return; } - // UPDATE LATER ... + auto inp_var_names = Output("Out"); + PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, + "The number of input variables should be greater than 0"); + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // todo (tangwei) made it async + for (size_t i = 0; i < inp_var_names.size(); i++) { + auto *var = scope.FindVar(inp_var_names[i]); + + PADDLE_ENFORCE(var != nullptr, + "Cannot find variable %s for save_combine_op", + inp_var_names[i]); + PADDLE_ENFORCE(var->IsType(), + "SaveCombineOp only supports LoDTensor, %s has wrong type", + inp_var_names[i]); + + std::string var_file; + var_file.append(dir); + var_file.append("/"); + var_file.append(inp_var_names[i]); + VLOG(3) << "ready to load var: " << inp_var_names[i]; + + auto &tensor = var->Get(); + + std::ifstream fin(var_file); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", + var_file); + DeserializeFromStream(fin, tensor, *dev_ctx); + fin.close(); + VLOG(3) << " load var: " << inp_var_names[i] << " finished"; + } } }; @@ -61,6 +102,7 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "(Tensor) The tensor need to be loaded"); AddComment(R"DOC( CheckpointLoad operator