diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index dc5457dba8794d6c161bb3eb397ce580aba9e30d..7308330e74e6748c63e856cbe441d1fd532ec873 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); + if (out_var->IsType()) { + SaveLodTensor(filename, place, out_var); + } else if (out_var->IsType()) { + SaveSelectedRows(filename, scope, place, out_var); + } else { + PADDLE_ENFORCE( + false, + "Load only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } } } @@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase { const platform::Place &place, framework::Variable *var) const { - auto &selectedRows = var->Get(); + auto *selectedRows = var->GetMutable(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();