From fe76244f0ee363c194265ebac7abbcc9cf5e5e68 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 15 Jun 2018 15:13:02 +0800 Subject: [PATCH] bug fix --- paddle/fluid/operators/save_op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 3d114538eb..3277d09ab2 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -70,7 +71,6 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); - auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -97,7 +97,7 @@ class SaveOp : public framework::OperatorBase { } SaveLodTensor(const std::string &filename, const platform::Place &place, - Variable *var) { + framework::Variable *var) { auto &tensor = var->Get(); // get device context from pool @@ -110,6 +110,7 @@ class SaveOp : public framework::OperatorBase { PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); + auto save_as_fp16 = Attr("save_as_fp16"); auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -124,11 +125,11 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } - fout.close() + fout.close(); } SaveSelectedRows(const std::string &filename, const platform::Place &place, - Variable *var) { + framework::Variable *var) { auto &selectedRows = var->Get(); // get device context from pool -- GitLab