提交 a501766a 编写于 作者: T tangwei12

load op add seletedRows

上级 549f0aa0
...@@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase { ...@@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase {
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name); out_var_name);
if (out_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, scope, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
} }
} }
...@@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
const platform::Place &place, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto &selectedRows = var->Get<framework::SelectedRows>(); auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册