From ca27f78e299a86fc1aca2c087270a6133eb1a79e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 08:16:40 +0800 Subject: [PATCH] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 7308330e74..dd24dacf42 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -1,3 +1,4 @@ + /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,22 +46,19 @@ class LoadOp : public framework::OperatorBase { out_var_name); if (out_var->IsType()) { - SaveLodTensor(filename, place, out_var); + LoadLodTensor(filename, place, out_var); } else if (out_var->IsType()) { - SaveSelectedRows(filename, scope, place, out_var); + LoadSelectedRows(filename, scope, place, out_var); } else { PADDLE_ENFORCE( false, "Load only support LoDTensor and SelectedRows, %s has wrong type", - iname); + out_var_name); } } - } void LoadLodTensor(const std::string &filename, const platform::Place &place, framework::Variable *var) const { - auto &tensor = var->Get(); - // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); @@ -68,10 +66,10 @@ class LoadOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", filename); - auto *tensor = out_var->GetMutable(); + auto *tensor = var->GetMutable(); DeserializeFromStream(fin, tensor, *dev_ctx); @@ -90,10 +88,11 @@ class LoadOp : public framework::OperatorBase { &fp16_tensor); // reset output tensor - out_var->Clear(); - tensor = out_var->GetMutable(); + var->Clear(); + tensor = var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); + } } void LoadSelectedRows(const std::string &filename, @@ -110,7 +109,7 @@ class LoadOp : public framework::OperatorBase { // FIXME(yuyang18): We save variable to local file now, but we should change // it to save an output stream. std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to write", + PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", filename); framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } -- GitLab