提交 16ecead8 编写于 作者: T tangwei12

load op optimize

上级 49c2d0c5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -35,6 +36,8 @@ class LoadOp : public framework::OperatorBase { ...@@ -35,6 +36,8 @@ class LoadOp : public framework::OperatorBase {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx); platform::RecordEvent record_event(Type(), dev_ctx);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
...@@ -46,31 +49,23 @@ class LoadOp : public framework::OperatorBase { ...@@ -46,31 +49,23 @@ class LoadOp : public framework::OperatorBase {
out_var_name); out_var_name);
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(filename, place, out_var); LoadLodTensor(fin, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(filename, scope, place, out_var); LoadSelectedRows(fin, place, out_var);
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, false,
"Load only support LoDTensor and SelectedRows, %s has wrong type", "Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name); out_var_name);
} }
} }
void LoadLodTensor(const std::string &filename, const platform::Place &place, void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
filename);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16"); auto load_as_fp16 = Attr<bool>("load_as_fp16");
...@@ -92,25 +87,15 @@ class LoadOp : public framework::OperatorBase { ...@@ -92,25 +87,15 @@ class LoadOp : public framework::OperatorBase {
tensor = var->GetMutable<framework::LoDTensor>(); tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod()); tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor); tensor->ShareDataWith(fp16_tensor);
} }
} }
void LoadSelectedRows(const std::string &filename, void LoadSelectedRows(std::istream &fin, const platform::Place &place,
const framework::Scope &scope,
const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>(); auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
filename);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx); framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册