diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index ba8b5dbb51c10a05d5d2af81cb36b2d0ed65247c..026820ca3032473f90186e7deb16fad12a3c8145 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -62,7 +62,7 @@ class CheckpointLoadOp : public framework::OperatorBase { return; } - auto inp_var_names = Output("Out"); + auto inp_var_names = Inputs("X"); PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, "The number of input variables should be greater than 0"); // get device context from pool @@ -102,7 +102,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddOutput("Out", "(Tensor) The tensor need to be loaded"); + AddInput( + "X", + "(vector) Input LoDTensors that need to be saved together in a file.") + .AsDuplicable(); AddComment(R"DOC( CheckpointLoad operator