提交 549f0aa0 编写于 作者: T tangwei12

load op add seletedRows

上级 bb10c370
......@@ -44,7 +44,24 @@ class LoadOp : public framework::OperatorBase {
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
}
}
void LoadLodTensor(const std::string &filename, const platform::Place &place,
framework::Variable *var) const {
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx);
......@@ -67,7 +84,25 @@ class LoadOp : public framework::OperatorBase {
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(const std::string &filename,
const framework::Scope &scope,
const platform::Place &place,
framework::Variable *var) const {
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open %s to write",
filename);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册