From 30b50dcf8cd07efedd3d99a36199f589b29a448a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 15 May 2018 17:23:48 +0800 Subject: [PATCH] fix Serial output type --- paddle/fluid/operators/checkpoint_save_op.cc | 25 +++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 7007ab9e1a1..7449352117b 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -87,7 +87,7 @@ class CheckpointSaveOp : public framework::OperatorBase { std::string *serial_num = serial_var->GetMutable(); serial_num->append("0"); dir.append("/"); - dir.append(serial_num); + dir.append(serial_num->c_str()); MkDirRecursively(dir.c_str()); auto inp_var_names = Inputs("X"); @@ -159,10 +159,29 @@ to a file on disk. } }; +class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output("Serial").front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class CheckpointSaveOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, - ops::CheckpointSaveOpProtoMaker); +REGISTER_OPERATOR(send_vars, ops::CheckpointSaveOp, + paddle::framework::EmptyGradOpMaker, + ops::CheckpointSaveOpProtoMaker, + ops::CheckpointSaveOpVarTypeInference, + ops::CheckpointSaveOpShapeInference); -- GitLab