diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 7007ab9e1a1a726be7b1e2f42c5d49443995488e..7449352117b58a4cb7d89edacb5211b4c5edb9ed 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -87,7 +87,7 @@ class CheckpointSaveOp : public framework::OperatorBase { std::string *serial_num = serial_var->GetMutable(); serial_num->append("0"); dir.append("/"); - dir.append(serial_num); + dir.append(serial_num->c_str()); MkDirRecursively(dir.c_str()); auto inp_var_names = Inputs("X"); @@ -159,10 +159,29 @@ to a file on disk. } }; +class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output("Serial").front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class CheckpointSaveOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, - ops::CheckpointSaveOpProtoMaker); +REGISTER_OPERATOR(send_vars, ops::CheckpointSaveOp, + paddle::framework::EmptyGradOpMaker, + ops::CheckpointSaveOpProtoMaker, + ops::CheckpointSaveOpVarTypeInference, + ops::CheckpointSaveOpShapeInference);