提交 91eae9cc 编写于 作者: T tangwei12

code style

上级 db6126ca
...@@ -69,7 +69,6 @@ class SaveOp : public framework::OperatorBase { ...@@ -69,7 +69,6 @@ class SaveOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto iname = Input("X"); auto iname = Input("X");
auto *var = scope.FindVar(iname); auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
...@@ -87,7 +86,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -87,7 +86,7 @@ class SaveOp : public framework::OperatorBase {
} }
} }
void SaveLodTensor( const platform::Place &place, void SaveLodTensor(const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
...@@ -132,8 +131,11 @@ class SaveOp : public framework::OperatorBase { ...@@ -132,8 +131,11 @@ class SaveOp : public framework::OperatorBase {
void SaveSelectedRows(const framework::Scope &scope, void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto *lt_var = scope.FindVar("loopup_table_path")->GetMutable<std::string>(); auto *lt_var =
PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable loopup_table_path for SaveSelectedRows"); scope.FindVar("loopup_table_path")->GetMutable<std::string>();
PADDLE_ENFORCE(
lt_var != nullptr,
"Can not find variable loopup_table_path for SaveSelectedRows");
std::string filename = lt_var->data(); std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
...@@ -195,17 +197,11 @@ class SaveOpShapeInference : public framework::InferShapeBase { ...@@ -195,17 +197,11 @@ class SaveOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override {} void operator()(framework::InferShapeContext *ctx) const override {}
}; };
} } // namespace operators
} } // namespace paddle
// namespace operators
// namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker,
paddle::framework::EmptyGradOpMaker, ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference,
ops::SaveOpProtoMaker,
ops::SaveOpVarTypeInference,
ops::SaveOpShapeInference); ops::SaveOpShapeInference);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册