diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 8f4b5049271c9592d2db268ea7ff2f5c8abc28b6..dc5457dba8794d6c161bb3eb397ce580aba9e30d 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -44,7 +44,24 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); - auto *tensor = out_var->GetMutable(); + } + } + + void LoadLodTensor(const std::string &filename, const platform::Place &place, + framework::Variable *var) const { + auto &tensor = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + + auto *tensor = out_var->GetMutable(); DeserializeFromStream(fin, tensor, *dev_ctx); @@ -67,7 +84,25 @@ class LoadOp : public framework::OperatorBase { tensor = out_var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); - } + } + + void LoadSelectedRows(const std::string &filename, + const framework::Scope &scope, + const platform::Place &place, + framework::Variable *var) const { + + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to write", + filename); + framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } };