提交 fe76244f 编写于 作者: T tangwei12

bug fix

上级 fb27c9a5
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -70,7 +71,6 @@ class SaveOp : public framework::OperatorBase { ...@@ -70,7 +71,6 @@ class SaveOp : public framework::OperatorBase {
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
auto save_as_fp16 = Attr<bool>("save_as_fp16");
if (FileExists(filename) && !overwrite) { if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
...@@ -97,7 +97,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -97,7 +97,7 @@ class SaveOp : public framework::OperatorBase {
} }
SaveLodTensor(const std::string &filename, const platform::Place &place, SaveLodTensor(const std::string &filename, const platform::Place &place,
Variable *var) { framework::Variable *var) {
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool // get device context from pool
...@@ -110,6 +110,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -110,6 +110,7 @@ class SaveOp : public framework::OperatorBase {
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename); filename);
auto save_as_fp16 = Attr<bool>("save_as_fp16");
auto in_dtype = framework::ToDataType(tensor.type()); auto in_dtype = framework::ToDataType(tensor.type());
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
...@@ -124,11 +125,11 @@ class SaveOp : public framework::OperatorBase { ...@@ -124,11 +125,11 @@ class SaveOp : public framework::OperatorBase {
} else { } else {
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(fout, tensor, dev_ctx);
} }
fout.close() fout.close();
} }
SaveSelectedRows(const std::string &filename, const platform::Place &place, SaveSelectedRows(const std::string &filename, const platform::Place &place,
Variable *var) { framework::Variable *var) {
auto &selectedRows = var->Get<framework::SelectedRows>(); auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool // get device context from pool
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册