提交 461d2fc0 编写于 作者: T tangwei12

rename ckpt -> checkpoint

上级 a1419f10
...@@ -242,7 +242,7 @@ op_library(save_op DEPS lod_tensor) ...@@ -242,7 +242,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor)
op_library(ckpt_save_op DEPS lod_tensor) op_library(checkpoint_save_op DEPS lod_tensor)
op_library(concat_op DEPS concat) op_library(concat_op DEPS concat)
# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency # FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
...@@ -278,6 +278,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea ...@@ -278,6 +278,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
cc_test(ckpt_save_op_test SRCS ckpt_save_op_test.cc DEPS ckpt_save_op) cc_test(checkpoint_save_op_test SRCS checkpoint_save_op_test.cc DEPS checkpoint_save_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
...@@ -15,9 +15,9 @@ limitations under the License. */ ...@@ -15,9 +15,9 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
USE_NO_KERNEL_OP(ckpt_save) USE_NO_KERNEL_OP(checkpoint_save)
TEST(CkptSaveOp, CPU) { TEST(CheckpointSaveOp, CPU) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUPlace place; paddle::platform::CPUPlace place;
...@@ -41,6 +41,6 @@ TEST(CkptSaveOp, CPU) { ...@@ -41,6 +41,6 @@ TEST(CkptSaveOp, CPU) {
attrs.insert({"file_path", std::string("tensor.save")}); attrs.insert({"file_path", std::string("tensor.save")});
auto save_op = paddle::framework::OpRegistry::CreateOp( auto save_op = paddle::framework::OpRegistry::CreateOp(
"ckpt_save", {{"X", {"test_var"}}}, {}, attrs); "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place); save_op->Run(scope, place);
} }
...@@ -57,11 +57,12 @@ static void MkDirRecursively(const char *fullpath) { ...@@ -57,11 +57,12 @@ static void MkDirRecursively(const char *fullpath) {
MkDir(fullpath); MkDir(fullpath);
} }
class CkptSaveOp : public framework::OperatorBase { class CheckpointSaveOp : public framework::OperatorBase {
public: public:
CkptSaveOp(const std::string &type, const framework::VariableNameMap &inputs, CheckpointSaveOp(const std::string &type,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &inputs,
const framework::AttributeMap &attrs) const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
private: private:
...@@ -122,9 +123,9 @@ class CkptSaveOp : public framework::OperatorBase { ...@@ -122,9 +123,9 @@ class CkptSaveOp : public framework::OperatorBase {
} }
}; };
class CkptSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CkptSaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) CheckpointSaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
...@@ -155,4 +156,5 @@ to a file on disk. ...@@ -155,4 +156,5 @@ to a file on disk.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(ckpt_save, ops::CkptSaveOp, ops::CkptSaveOpProtoMaker); REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
ops::CheckpointSaveOpProtoMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册