提交 2e25e739 编写于 作者: T tangwei12

write checkpoint_load code simply

上级 c80125f2
...@@ -36,14 +36,6 @@ static bool FileExists(const std::string &filepath) { ...@@ -36,14 +36,6 @@ static bool FileExists(const std::string &filepath) {
return (stat(filepath.c_str(), &buffer) == 0); return (stat(filepath.c_str(), &buffer) == 0);
} }
static std::string DirName(const std::string &filepath) {
auto pos = filepath.rfind(kSEP);
if (pos == std::string::npos) {
return "";
}
return filepath.substr(0, pos);
}
class CheckpointLoadOp : public framework::OperatorBase { class CheckpointLoadOp : public framework::OperatorBase {
public: public:
CheckpointLoadOp(const std::string &type, CheckpointLoadOp(const std::string &type,
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
USE_NO_KERNEL_OP(checkpoint_save) USE_NO_KERNEL_OP(checkpoint_save)
USE_NO_KERNEL_OP(checkpoint_load)
TEST(CheckpointSaveOp, CPU) { TEST(CheckpointSaveOp, CPU) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
...@@ -37,10 +38,27 @@ TEST(CheckpointSaveOp, CPU) { ...@@ -37,10 +38,27 @@ TEST(CheckpointSaveOp, CPU) {
expect[i] = static_cast<float>(paddle::platform::float16(i)); expect[i] = static_cast<float>(paddle::platform::float16(i));
} }
scope.Var("SERIAL_NUMBER");
paddle::framework::AttributeMap attrs; paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("tensor/ckpt")}); attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp( auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, {}, attrs); "checkpoint_save", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}},
attrs);
save_op->Run(scope, place);
}
TEST(CheckpointLoadOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
scope.Var("test_var");
paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});
auto save_op =
paddle::framework::OpRegistry::CreateOp("checkpoint_load", {}, {}, attrs);
save_op->Run(scope, place); save_op->Run(scope, place);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册