diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index b2ca59f2b5b5bf34fa48433ed9f73abf429f5f73..8edf3b6429dbddfaa1c78d7953595f655bee3b85 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -36,14 +36,6 @@ static bool FileExists(const std::string &filepath) { return (stat(filepath.c_str(), &buffer) == 0); } -static std::string DirName(const std::string &filepath) { - auto pos = filepath.rfind(kSEP); - if (pos == std::string::npos) { - return ""; - } - return filepath.substr(0, pos); -} - class CheckpointLoadOp : public framework::OperatorBase { public: CheckpointLoadOp(const std::string &type, diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc index 7b5aa7bcde16eae6566a4edc7ccd71afe007e0f8..1445d9f9acffc9b104a6ee540b7e42b41df0b0f6 100644 --- a/paddle/fluid/operators/checkpoint_op_test.cc +++ b/paddle/fluid/operators/checkpoint_op_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" USE_NO_KERNEL_OP(checkpoint_save) +USE_NO_KERNEL_OP(checkpoint_load) TEST(CheckpointSaveOp, CPU) { paddle::framework::Scope scope; @@ -37,10 +38,27 @@ TEST(CheckpointSaveOp, CPU) { expect[i] = static_cast(paddle::platform::float16(i)); } + scope.Var("SERIAL_NUMBER"); + paddle::framework::AttributeMap attrs; - attrs.insert({"dir", std::string("tensor/ckpt")}); + attrs.insert({"dir", std::string("ckpt")}); auto save_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs); + "checkpoint_save", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}}, + attrs); + save_op->Run(scope, place); +} + +TEST(CheckpointLoadOp, CPU) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + + scope.Var("test_var"); + + paddle::framework::AttributeMap attrs; + attrs.insert({"dir", std::string("ckpt")}); + + auto save_op = + paddle::framework::OpRegistry::CreateOp("checkpoint_load", {}, {}, attrs); save_op->Run(scope, place); }