提交 ca27f78e 编写于 作者: T tangwei12

load op add seletedRows

上级 a501766a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -45,22 +46,19 @@ class LoadOp : public framework::OperatorBase {
out_var_name);
if (out_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, out_var);
LoadLodTensor(filename, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, scope, place, out_var);
LoadSelectedRows(filename, scope, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
iname);
out_var_name);
}
}
}
void LoadLodTensor(const std::string &filename, const platform::Place &place,
framework::Variable *var) const {
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
......@@ -68,10 +66,10 @@ class LoadOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
filename);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx);
......@@ -90,10 +88,11 @@ class LoadOp : public framework::OperatorBase {
&fp16_tensor);
// reset output tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(const std::string &filename,
......@@ -110,7 +109,7 @@ class LoadOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to write",
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to read",
filename);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册