提交 461d2fc0 编写于 作者: T tangwei12

rename ckpt -> checkpoint

上级 a1419f10
......@@ -242,7 +242,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(ckpt_save_op DEPS lod_tensor)
op_library(checkpoint_save_op DEPS lod_tensor)
op_library(concat_op DEPS concat)
# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
......@@ -278,6 +278,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
cc_test(ckpt_save_op_test SRCS ckpt_save_op_test.cc DEPS ckpt_save_op)
cc_test(checkpoint_save_op_test SRCS checkpoint_save_op_test.cc DEPS checkpoint_save_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
......@@ -15,9 +15,9 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
USE_NO_KERNEL_OP(ckpt_save)
USE_NO_KERNEL_OP(checkpoint_save)
TEST(CkptSaveOp, CPU) {
TEST(CheckpointSaveOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
......@@ -41,6 +41,6 @@ TEST(CkptSaveOp, CPU) {
attrs.insert({"file_path", std::string("tensor.save")});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"ckpt_save", {{"X", {"test_var"}}}, {}, attrs);
"checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place);
}
......@@ -57,11 +57,12 @@ static void MkDirRecursively(const char *fullpath) {
MkDir(fullpath);
}
class CkptSaveOp : public framework::OperatorBase {
class CheckpointSaveOp : public framework::OperatorBase {
public:
CkptSaveOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
CheckpointSaveOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
......@@ -122,9 +123,9 @@ class CkptSaveOp : public framework::OperatorBase {
}
};
class CkptSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CkptSaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
CheckpointSaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......@@ -155,4 +156,5 @@ to a file on disk.
namespace ops = paddle::operators;
REGISTER_OPERATOR(ckpt_save, ops::CkptSaveOp, ops::CkptSaveOpProtoMaker);
REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
ops::CheckpointSaveOpProtoMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册