From a501766ab16362a0cc35d6ad75e68c35859df166 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 03:22:03 +0800 Subject: [PATCH] load op add seletedRows --- paddle/fluid/operators/load_op.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index dc5457dba87..7308330e74e 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); + if (out_var->IsType()) { + SaveLodTensor(filename, place, out_var); + } else if (out_var->IsType()) { + SaveSelectedRows(filename, scope, place, out_var); + } else { + PADDLE_ENFORCE( + false, + "Load only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } } } @@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase { const platform::Place &place, framework::Variable *var) const { - auto &selectedRows = var->Get(); + auto *selectedRows = var->GetMutable(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); -- GitLab