提交 ca27f78e 编写于 作者: T tangwei12

load op add seletedRows

上级 a501766a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -45,22 +46,19 @@ class LoadOp : public framework::OperatorBase { ...@@ -45,22 +46,19 @@ class LoadOp : public framework::OperatorBase {
out_var_name); out_var_name);
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, out_var); LoadLodTensor(filename, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, scope, place, out_var); LoadSelectedRows(filename, scope, place, out_var);
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, false,
"Load only support LoDTensor and SelectedRows, %s has wrong type", "Load only support LoDTensor and SelectedRows, %s has wrong type",
iname); out_var_name);
} }
} }
}
void LoadLodTensor(const std::string &filename, const platform::Place &place, void LoadLodTensor(const std::string &filename, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
...@@ -68,10 +66,10 @@ class LoadOp : public framework::OperatorBase { ...@@ -68,10 +66,10 @@ class LoadOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
filename); filename);
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx); DeserializeFromStream(fin, tensor, *dev_ctx);
...@@ -90,10 +88,11 @@ class LoadOp : public framework::OperatorBase { ...@@ -90,10 +88,11 @@ class LoadOp : public framework::OperatorBase {
&fp16_tensor); &fp16_tensor);
// reset output tensor // reset output tensor
out_var->Clear(); var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod()); tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor); tensor->ShareDataWith(fp16_tensor);
}
} }
void LoadSelectedRows(const std::string &filename, void LoadSelectedRows(const std::string &filename,
...@@ -110,7 +109,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -110,7 +109,7 @@ class LoadOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to write", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
filename); filename);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx); framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册