diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 3d114538eb881bc526e51e1aa442f85e00c7647d..3277d09ab20ae64ada7f7a66220b286400d03753 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -70,7 +71,6 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); - auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -97,7 +97,7 @@ class SaveOp : public framework::OperatorBase { } SaveLodTensor(const std::string &filename, const platform::Place &place, - Variable *var) { + framework::Variable *var) { auto &tensor = var->Get(); // get device context from pool @@ -110,6 +110,7 @@ class SaveOp : public framework::OperatorBase { PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); + auto save_as_fp16 = Attr("save_as_fp16"); auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -124,11 +125,11 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } - fout.close() + fout.close(); } SaveSelectedRows(const std::string &filename, const platform::Place &place, - Variable *var) { + framework::Variable *var) { auto &selectedRows = var->Get(); // get device context from pool