Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6d53dcee
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d53dcee
编写于
5月 17, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimized checkpoint serial number and folder
上级
4220b31d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
143 addition
and
91 deletion
+143
-91
paddle/fluid/operators/checkpoint_load_op.cc
paddle/fluid/operators/checkpoint_load_op.cc
+85
-36
paddle/fluid/operators/checkpoint_op_test.cc
paddle/fluid/operators/checkpoint_op_test.cc
+5
-5
paddle/fluid/operators/checkpoint_save_op.cc
paddle/fluid/operators/checkpoint_save_op.cc
+53
-50
未找到文件。
paddle/fluid/operators/checkpoint_load_op.cc
浏览文件 @
6d53dcee
...
...
@@ -17,6 +17,10 @@ limitations under the License. */
#include <fstream>
#include <numeric>
#include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
...
...
@@ -30,41 +34,24 @@ namespace operators {
constexpr
char
kSEP
=
'/'
;
// write empty file named _SUCCESS
const
char
SUCCESS
[]
=
"_SUCCESS"
;
const
char
SERIAL_VAR
[]
=
"SERIAL_NUMBER"
;
static
bool
FileExists
(
const
std
::
string
&
filepath
)
{
struct
stat
buffer
;
return
(
stat
(
filepath
.
c_str
(),
&
buffer
)
==
0
);
}
class
CheckpointLoadOp
:
public
framework
::
OperatorBase
{
public:
CheckpointLoadOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
VLOG
(
3
)
<<
"Load checkpoint from dir: "
<<
dir
;
std
::
string
success
;
success
.
append
(
dir
);
success
.
append
(
"/"
);
success
.
append
(
SUCCESS
);
bool
is_present
=
FileExists
(
success
);
if
(
!
is_present
)
{
VLOG
(
3
)
<<
"can not find _SUCCESS from path: "
<<
success
;
return
;
}
static
std
::
string
GenePath
(
const
std
::
string
&
dir
,
const
std
::
string
&
file
)
{
boost
::
filesystem
::
path
dir
(
dir
);
boost
::
filesystem
::
path
file
(
file
);
boost
::
filesystem
::
path
full_path
=
dir
/
file
;
return
full_path
;
}
auto
inp_var_names
=
Inputs
(
"X"
);
PADDLE_ENFORCE_GT
(
static_cast
<
int
>
(
inp_var_names
.
size
()),
0
,
"The number of input variables should be greater than 0"
);
static
void
LoadInputVars
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>
&
inp_var_names
,
const
std
::
string
&
dir
)
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
...
@@ -80,21 +67,76 @@ class CheckpointLoadOp : public framework::OperatorBase {
"SaveCombineOp only supports LoDTensor, %s has wrong type"
,
inp_var_names
[
i
]);
std
::
string
var_file
;
var_file
.
append
(
dir
);
var_file
.
append
(
"/"
);
var_file
.
append
(
inp_var_names
[
i
]);
VLOG
(
3
)
<<
"ready to load var: "
<<
inp_var_names
[
i
];
std
::
string
var_file
=
GenePath
(
dir
,
inp_var_names
[
i
]);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
std
::
ifstream
fin
(
var_file
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
var_file
);
framework
::
DeserializeFromStream
(
fin
,
tensor
,
dev_ctx
);
fin
.
close
();
VLOG
(
3
)
<<
" load var: "
<<
inp_var_names
[
i
]
<<
" finished"
;
}
}
static
void
LoadStringArgv
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
string
&
argv
,
const
std
::
string
&
dir
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
argv
.
size
();
i
++
)
{
auto
*
var
=
scope
.
FindVar
(
inp_var_names
[
i
]);
std
::
string
*
var_str
=
var
->
GetMutable
<
std
::
string
>
();
std
::
string
var_file
=
GenePath
(
dir
,
argv
);
std
::
ifstream
fin
(
var_file
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
var_file
);
std
::
getline
(
fin
,
var_str
);
fin
.
close
();
VLOG
(
3
)
<<
" load String argv: "
<<
argv
<<
" value is: "
<<
var_str
;
}
}
class
CheckpointLoadOp
:
public
framework
::
OperatorBase
{
public:
CheckpointLoadOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
int
serial_num
=
Attr
<
int
>
(
"Serial"
);
auto
*
serial_var
=
scope
.
FindVar
(
SERIAL_VAR
);
serial_var
=
serial_num
;
VLOG
(
1
)
<<
"CheckpointLoadOp set "
<<
SERIAL_NUMBER
<<
" value: "
<<
serial_num
;
std
::
string
success
;
=
GenePath
(
dir
,
std
::
to_string
(
serial_num
));
VLOG
(
3
)
<<
"Load checkpoint from dir: "
<<
success
;
success
=
GenePath
(
success
,
SUCCESS
);
bool
is_present
=
FileExists
(
success
);
if
(
!
is_present
)
{
VLOG
(
1
)
<<
"CheckpointLoadOp can not find "
<<
SUCCESS
<<
" from: "
<<
success
;
return
;
}
VLOG
(
3
)
<<
"Ready to load vars to scope"
;
auto
inp_var_names
=
Inputs
(
"X"
);
PADDLE_ENFORCE_GT
(
static_cast
<
int
>
(
inp_var_names
.
size
()),
0
,
"The number of input variables should be greater than 0"
);
LoadInputVars
(
scope
,
place
,
&
inp_var_names
);
VLOG
(
3
)
<<
"Ready to load string argv to scope"
;
auto
argv
=
Inputs
(
"Argv"
);
LoadStringArgv
(
scope
,
place
,
&
argv
,
&
dir
);
}
};
...
...
@@ -106,6 +148,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
AddInput
(
"Argv"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
AddComment
(
R"DOC(
CheckpointLoad operator
...
...
@@ -113,6 +159,9 @@ This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC"
);
AddAttr
<
int
>
(
"Serial"
,
"(int)"
"The serial number of the checkpoint will to be load."
);
AddAttr
<
std
::
string
>
(
"dir"
,
"(string)"
...
...
paddle/fluid/operators/checkpoint_op_test.cc
浏览文件 @
6d53dcee
...
...
@@ -44,8 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
attrs
.
insert
({
"dir"
,
std
::
string
(
"ckpt"
)});
auto
save_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"checkpoint_save"
,
{{
"X"
,
{
"test_var"
}}},
{{
"Serial"
,
{
"SERIAL_NUMBER"
}}},
attrs
);
"checkpoint_save"
,
{{
"X"
,
{
"test_var"
}}},
attrs
);
save_op
->
Run
(
scope
,
place
);
}
...
...
@@ -58,7 +57,8 @@ TEST(CheckpointLoadOp, CPU) {
paddle
::
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"dir"
,
std
::
string
(
"ckpt"
)});
auto
save_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"checkpoint_load"
,
{{
"X"
,
{
"test_var"
}}},
{},
attrs
);
save_op
->
Run
(
scope
,
place
);
auto
load_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"checkpoint_load"
,
{{
"X"
,
{
"test_var"
}}},
{{
"Serial"
,
{
"SERIAL_NUMBER"
}}},
attrs
);
load_op
->
Run
(
scope
,
place
);
}
paddle/fluid/operators/checkpoint_save_op.cc
浏览文件 @
6d53dcee
...
...
@@ -17,6 +17,10 @@ limitations under the License. */
#include <fstream>
#include <numeric>
#include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
...
...
@@ -30,6 +34,14 @@ namespace operators {
constexpr
char
kSEP
=
'/'
;
// write empty file named _SUCCESS
const
char
SUCCESS
[]
=
"_SUCCESS"
;
const
char
SERIAL_VAR
[]
=
"SERIAL_NUMBER"
;
static
std
::
string
GenePath
(
const
std
::
string
&
dir
,
const
std
::
string
&
file
)
{
boost
::
filesystem
::
path
dir
(
dir
);
boost
::
filesystem
::
path
file
(
file
);
boost
::
filesystem
::
path
full_path
=
dir
/
file
;
return
full_path
;
}
static
bool
FileExists
(
const
std
::
string
&
filepath
)
{
struct
stat
buffer
;
...
...
@@ -72,24 +84,20 @@ class CheckpointSaveOp : public framework::OperatorBase {
auto
dir
=
Attr
<
std
::
string
>
(
"dir"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
serial_num
=
scope
.
FindVar
(
SERIAL_VAR
);
if
(
serial_num
==
nullptr
)
{
serial_num
=
scope
.
Var
(
SERIAL_VAR
);
}
serial_num
=
serial_num
+
1
;
dir
=
GenePath
(
dir
,
std
::
to_string
(
serial_num
));
bool
is_present
=
FileExists
(
dir
);
if
(
is_present
&&
!
overwrite
)
{
return
;
// todo(tangwei) judge the folder is exist
// PADDLE_THROW("%s exists!, cannot save_combine to it when
// overwrite=false",
// dir, overwrite);
PADDLE_THROW
(
"%s exists!, checkpoint save cannot to overwrite it"
,
dir
,
overwrite
);
}
MkDirRecursively
(
dir
.
c_str
());
auto
serial_var_name
=
Output
(
"Serial"
);
auto
*
serial_var
=
scope
.
FindVar
(
serial_var_name
);
std
::
string
*
serial_num
=
serial_var
->
GetMutable
<
std
::
string
>
();
serial_num
->
append
(
"0"
);
dir
.
append
(
"/"
);
dir
.
append
(
serial_num
->
c_str
());
MkDirRecursively
(
dir
.
c_str
());
auto
inp_var_names
=
Inputs
(
"X"
);
PADDLE_ENFORCE_GT
(
static_cast
<
int
>
(
inp_var_names
.
size
()),
0
,
"The number of input variables should be greater than 0"
);
...
...
@@ -101,30 +109,24 @@ class CheckpointSaveOp : public framework::OperatorBase {
// todo (tangwei) made it async
for
(
size_t
i
=
0
;
i
<
inp_var_names
.
size
();
i
++
)
{
auto
*
var
=
scope
.
FindVar
(
inp_var_names
[
i
]);
std
::
string
var_file
;
var_file
.
append
(
dir
);
var_file
.
append
(
"/"
);
var_file
.
append
(
inp_var_names
[
i
]);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s for
save_combine_
op"
,
"Cannot find variable %s for
checkpoint save
op"
,
inp_var_names
[
i
]);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
"SaveCombineOp only supports LoDTensor, %s has wrong type"
,
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
"CheckpointSaveOp only supports LoDTensor, %s has wrong type"
,
inp_var_names
[
i
]);
auto
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// Serialize tensors one by one
std
::
string
var_file
=
GenePath
(
dir
,
inp_var_names
[
i
]);
std
::
ofstream
fout
(
var_file
);
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
fout
.
close
();
}
std
::
string
success
;
success
.
append
(
dir
);
success
.
append
(
"/"
);
success
.
append
(
SUCCESS
);
std
::
string
success
=
GenePath
(
dir
,
SUCCESS
);
std
::
ofstream
fout
(
success
);
fout
.
close
();
}
...
...
@@ -138,7 +140,6 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
AddOutput
(
"Serial"
,
"the serial number"
);
AddComment
(
R"DOC(
CheckpointSave operator
...
...
@@ -150,30 +151,29 @@ to a file on disk.
"Delete the output dir if it exists."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"dir"
,
AddAttr
<
std
::
string
>
(
"dir"
,
"(string)"
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
"The dir
where the LoDTensor variables will be saved."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
}
};
class
CheckpointSaveOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"Serial"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
CheckpointSaveOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
//
class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference {
//
public:
//
void operator()(const framework::OpDesc &op_desc,
//
framework::BlockDesc *block) const override {
//
auto out_var_name = op_desc.Output("Serial").front();
//
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
//
auto var_type = framework::proto::VarType::RAW;
//
out_var.SetType(var_type);
//
}
//
};
//
class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
//
public:
//
void operator()(framework::InferShapeContext *ctx) const override {}
//
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -181,7 +181,10 @@ class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
checkpoint_save
,
ops
::
CheckpointSaveOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CheckpointSaveOpProtoMaker
,
ops
::
CheckpointSaveOpVarTypeInference
,
ops
::
CheckpointSaveOpShapeInference
);
ops
::
CheckpointSaveOpProtoMaker
);
// REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
// paddle::framework::EmptyGradOpMaker,
// ops::CheckpointSaveOpProtoMaker,
// ops::CheckpointSaveOpVarTypeInference,
// ops::CheckpointSaveOpShapeInference);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录