From 16ecead837c940d52109769c60791af874eee51a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 17:53:38 +0800 Subject: [PATCH] load op optimize --- paddle/fluid/operators/load_op.cc | 37 +++++++++---------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index e75fc4d67..6be8fdb0d 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -1,4 +1,5 @@ - +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,6 +36,8 @@ class LoadOp : public framework::OperatorBase { auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); platform::RecordEvent record_event(Type(), dev_ctx); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. auto filename = Attr("file_path"); std::ifstream fin(filename); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", @@ -46,31 +49,23 @@ class LoadOp : public framework::OperatorBase { out_var_name); if (out_var->IsType()) { - LoadLodTensor(filename, place, out_var); + LoadLodTensor(fin, place, out_var); } else if (out_var->IsType()) { - LoadSelectedRows(filename, scope, place, out_var); + LoadSelectedRows(fin, place, out_var); } else { PADDLE_ENFORCE( false, "Load only support LoDTensor and SelectedRows, %s has wrong type", out_var_name); } - } + } - void LoadLodTensor(const std::string &filename, const platform::Place &place, + void LoadLodTensor(std::istream &fin, const platform::Place &place, framework::Variable *var) const { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", - filename); - - auto *tensor = var->GetMutable(); - + auto *tensor = var->GetMutable(); DeserializeFromStream(fin, tensor, dev_ctx); auto load_as_fp16 = Attr("load_as_fp16"); @@ -92,25 +87,15 @@ class LoadOp : public framework::OperatorBase { tensor = var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); - } + } } - void LoadSelectedRows(const std::string &filename, - const framework::Scope &scope, - const platform::Place &place, + void LoadSelectedRows(std::istream &fin, const platform::Place &place, framework::Variable *var) const { - auto *selectedRows = var->GetMutable(); - // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), "Cannot open %s to read", - filename); framework::DeserializeFromStream(fin, selectedRows, dev_ctx); } }; -- GitLab