From 549f0aa0d3ee482afdac53f72cc532f5f42e0382 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 03:16:38 +0800 Subject: [PATCH] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 39 +++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 8f4b5049271..dc5457dba87 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -44,7 +44,24 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); - auto *tensor = out_var->GetMutable(); + } + } + + void LoadLodTensor(const std::string &filename, const platform::Place &place, + framework::Variable *var) const { + auto &tensor = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + + auto *tensor = out_var->GetMutable(); DeserializeFromStream(fin, tensor, *dev_ctx); @@ -67,7 +84,25 @@ class LoadOp : public framework::OperatorBase { tensor = out_var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); - } + } + + void LoadSelectedRows(const std::string &filename, + const framework::Scope &scope, + const platform::Place &place, + framework::Variable *var) const { + + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to write", + filename); + framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } }; -- GitLab