Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a4fd3756
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a4fd3756
编写于
5月 18, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix
上级
f9d4b9da
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
95 addition
and
50 deletion
+95
-50
paddle/fluid/operators/checkpoint_load_op.cc
paddle/fluid/operators/checkpoint_load_op.cc
+56
-29
paddle/fluid/operators/checkpoint_op_test.cc
paddle/fluid/operators/checkpoint_op_test.cc
+20
-4
paddle/fluid/operators/checkpoint_save_op.cc
paddle/fluid/operators/checkpoint_save_op.cc
+19
-17
未找到文件。
paddle/fluid/operators/checkpoint_load_op.cc
浏览文件 @
a4fd3756
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <fstream>
#include <numeric>
#include <sstream>
#include <streambuf>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
...
...
@@ -43,7 +44,13 @@ static std::string GenePath(const std::string &dir, const std::string &file) {
file_path
.
append
(
file_path
);
file_path
.
append
(
"/"
);
file_path
.
append
(
file
);
return
full_path
;
return
file_path
;
}
static
bool
IsNumber
(
const
std
::
string
&
s
)
{
std
::
string
::
const_iterator
it
=
s
.
begin
();
while
(
it
!=
s
.
end
()
&&
std
::
isdigit
(
*
it
))
++
it
;
return
!
s
.
empty
()
&&
it
==
s
.
end
();
}
static
void
LoadInputVars
(
const
framework
::
Scope
&
scope
,
...
...
@@ -62,7 +69,7 @@ static void LoadInputVars(const framework::Scope &scope,
"Cannot find variable %s for save_combine_op"
,
inp_var_names
[
i
]);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
"
Save
CombineOp only supports LoDTensor, %s has wrong type"
,
"
Load
CombineOp only supports LoDTensor, %s has wrong type"
,
inp_var_names
[
i
]);
std
::
string
var_file
=
GenePath
(
dir
,
inp_var_names
[
i
]);
...
...
@@ -78,21 +85,18 @@ static void LoadInputVars(const framework::Scope &scope,
static
void
LoadStringArgv
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
string
&
argv
,
const
std
::
string
&
dir
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
const
std
::
vector
<
std
::
string
>
&
argv
,
const
std
::
string
&
dir
)
{
for
(
size_t
i
=
0
;
i
<
argv
.
size
();
i
++
)
{
auto
*
var
=
scope
.
FindVar
(
inp_var_names
[
i
]);
auto
*
var
=
scope
.
FindVar
(
argv
[
i
]);
std
::
string
*
var_str
=
var
->
GetMutable
<
std
::
string
>
();
std
::
string
var_file
=
GenePath
(
dir
,
argv
);
std
::
string
var_file
=
GenePath
(
dir
,
argv
[
i
]);
std
::
ifstream
fin
(
var_file
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
var_file
);
std
::
getline
(
fin
,
var_str
);
std
::
getline
(
fin
,
*
var_str
);
fin
.
close
();
VLOG
(
3
)
<<
" load String argv: "
<<
argv
<<
" value is: "
<<
var_str
;
VLOG
(
3
)
<<
" load String argv: "
<<
argv
[
i
]
<<
" value is: "
<<
var_str
;
}
}
...
...
@@ -108,22 +112,24 @@ class CheckpointLoadOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
std
::
string
serial_num
=
Attr
<
std
::
string
>
(
"Serial"
);
std
::
string
serial_num_attr
=
Attr
<
std
::
string
>
(
"Serial"
);
PADDLE_ENFORCE
(
IsNumber
(
serial_num_attr
),
"Checkpoint Serial must be a number"
);
std
::
string
serial_var_name
=
std
::
string
(
SERIAL_VAR
);
auto
*
serial_var
=
scope
.
FindVar
(
serial_var_name
);
if
(
serial_var
==
nullptr
)
{
*
serial_var
=
scope
.
Var
(
serial_var_name
);
auto
*
serial_tmp
=
serial_var
->
GetMutable
<
std
::
string
>
();
serial_tmp
->
append
(
"0"
);
}
PADDLE_ENFORCE
(
serial_var
!=
nullptr
,
"Cannot find variable %s for checkpoint_load_op"
,
serial_var_name
);
auto
*
serial_num
=
serial_var
->
GetMutable
<
std
::
string
>
();
VLOG
(
1
)
<<
"CheckpointLoadOp set "
<<
SERIAL_NUMBER
serial_num
=
serial_num_attr
;
VLOG
(
1
)
<<
"CheckpointLoadOp set "
<<
SERIAL_VAR
<<
" value: "
<<
serial_num
;
std
::
string
success
=
GenePath
(
dir
,
serial_num
);
std
::
string
success
=
GenePath
(
dir
,
serial_num
->
c_str
()
);
VLOG
(
3
)
<<
"Load checkpoint from dir: "
<<
success
;
success
=
GenePath
(
success
,
SUCCESS
);
bool
is_present
=
FileExists
(
success
);
...
...
@@ -137,11 +143,11 @@ class CheckpointLoadOp : public framework::OperatorBase {
auto
inp_var_names
=
Inputs
(
"X"
);
PADDLE_ENFORCE_GT
(
static_cast
<
int
>
(
inp_var_names
.
size
()),
0
,
"The number of input variables should be greater than 0"
);
LoadInputVars
(
scope
,
place
,
&
inp_var_names
);
LoadInputVars
(
scope
,
place
,
inp_var_names
,
dir
);
VLOG
(
3
)
<<
"Ready to load string argv to scope"
;
auto
argv
=
Inputs
(
"Argv"
);
LoadStringArgv
(
scope
,
place
,
&
argv
,
&
dir
);
//
VLOG(3) << "Ready to load string argv to scope";
// auto argv = Output
("Argv");
// LoadStringArgv(scope, place, argv,
dir);
}
};
...
...
@@ -153,14 +159,13 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
Add
In
put
(
Add
Out
put
(
"Argv"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
"(vector) Input LoDTensors that need to be saved together in a file."
);
AddComment
(
R"DOC(
CheckpointLoad operator
This operator will serialize and write a list of input LoDTensor variables
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC"
);
...
...
@@ -177,10 +182,32 @@ to a file on disk.
}
};
class
CheckpointLoadOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"Argv"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
CheckpointLoadOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
checkpoint_load
,
ops
::
CheckpointLoadOp
,
ops
::
CheckpointLoadOpProtoMaker
);
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CheckpointLoadOpProtoMaker
,
ops
::
CheckpointLoadOpVarTypeInference
,
ops
::
CheckpointLoadOpShapeInference
);
// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
// ops::CheckpointLoadOpProtoMaker);
paddle/fluid/operators/checkpoint_op_test.cc
浏览文件 @
a4fd3756
...
...
@@ -44,7 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
attrs
.
insert
({
"dir"
,
std
::
string
(
"ckpt"
)});
auto
save_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"checkpoint_save"
,
{{
"X"
,
{
"test_var"
}}},
attrs
);
"checkpoint_save"
,
{{
"X"
,
{
"test_var"
}}},
{},
attrs
);
save_op
->
Run
(
scope
,
place
);
}
...
...
@@ -52,13 +52,29 @@ TEST(CheckpointLoadOp, CPU) {
paddle
::
framework
::
Scope
scope
;
paddle
::
platform
::
CPUPlace
place
;
scope
.
Var
(
"test_var"
);
auto
var
=
scope
.
Var
(
"test_var"
);
auto
tensor
=
var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
tensor
->
Resize
({
3
,
10
});
paddle
::
framework
::
LoD
expect_lod
;
expect_lod
.
resize
(
1
);
expect_lod
[
0
].
push_back
(
0
);
expect_lod
[
0
].
push_back
(
1
);
expect_lod
[
0
].
push_back
(
2
);
expect_lod
[
0
].
push_back
(
3
);
tensor
->
set_lod
(
expect_lod
);
float
*
expect
=
tensor
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
tensor
->
numel
();
++
i
)
{
expect
[
i
]
=
static_cast
<
float
>
(
paddle
::
platform
::
float16
(
i
));
}
scope
.
Var
(
"SERIAL_NUMBER"
);
paddle
::
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"dir"
,
std
::
string
(
"ckpt"
)});
attrs
.
insert
({
"Serial"
,
std
::
string
(
"SERIAL_NUMBER"
)});
auto
load_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"checkpoint_load"
,
{{
"X"
,
{
"test_var"
}}},
{{
"Serial"
,
{
"SERIAL_NUMBER"
}}},
attrs
);
"checkpoint_load"
,
{{
"X"
,
{
"test_var"
}}},
{{
"Argv"
,
{}}},
attrs
);
load_op
->
Run
(
scope
,
place
);
}
paddle/fluid/operators/checkpoint_save_op.cc
浏览文件 @
a4fd3756
...
...
@@ -33,12 +33,18 @@ constexpr char kSEP = '/';
const
char
SUCCESS
[]
=
"_SUCCESS"
;
const
char
SERIAL_VAR
[]
=
"SERIAL_NUMBER"
;
static
bool
IsNumber
(
const
std
::
string
&
s
)
{
std
::
string
::
const_iterator
it
=
s
.
begin
();
while
(
it
!=
s
.
end
()
&&
std
::
isdigit
(
*
it
))
++
it
;
return
!
s
.
empty
()
&&
it
==
s
.
end
();
}
static
std
::
string
GenePath
(
const
std
::
string
&
dir
,
const
std
::
string
&
file
)
{
std
::
string
file_path
;
file_path
.
append
(
file_path
);
file_path
.
append
(
dir
);
file_path
.
append
(
"/"
);
file_path
.
append
(
file
);
return
f
ull
_path
;
return
f
ile
_path
;
}
static
bool
FileExists
(
const
std
::
string
&
filepath
)
{
...
...
@@ -79,28 +85,24 @@ class CheckpointSaveOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
dir
=
Attr
<
std
::
string
>
(
"dir"
);
auto
ck_
dir
=
Attr
<
std
::
string
>
(
"dir"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
std
::
string
serial_var_name
=
std
::
string
(
SERIAL_VAR
);
auto
*
serial_var
=
scope
.
FindVar
(
serial_var_name
);
if
(
serial_var
==
nullptr
)
{
*
serial_var
=
scope
.
Var
(
serial_var_name
);
*
serial_tmp
=
serial_var
->
GetMutable
<
std
::
string
>
();
serial_tmp
->
append
(
"0"
);
}
auto
*
serial_num
=
serial_var
->
GetMutable
<
std
::
string
>
();
VLOG
(
1
)
<<
"CheckpointSaveOp get "
<<
SERIAL_NUMBER
auto
*
serial_num
=
scope
.
FindVar
(
serial_var_name
)
->
GetMutable
<
std
::
string
>
();
VLOG
(
1
)
<<
"CheckpointSaveOp get "
<<
SERIAL_VAR
<<
" value: "
<<
serial_num
;
auto
*
serial_num
=
serial_var
->
GetMutable
<
std
::
string
>
();
serial_num
->
append
(
"1"
);
if
(
!
IsNumber
(
serial_num
))
{
serial_num
=
"0"
;
}
dir
=
GenePath
(
dir
,
serial_num
);
std
::
string
dir
=
GenePath
(
ck_dir
,
serial_num
->
c_str
());
VLOG
(
1
)
<<
"CheckpointSaveOp current dir: "
<<
dir
;
bool
is_present
=
FileExists
(
dir
);
if
(
is_present
&&
!
overwrite
)
{
PADDLE_THROW
(
"%s exists!, checkpoint save cannot to
overwrite it"
,
dir
,
PADDLE_THROW
(
"%s exists!, checkpoint save cannot to overwrite it"
,
dir
,
overwrite
);
}
MkDirRecursively
(
dir
.
c_str
());
...
...
@@ -150,7 +152,7 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
CheckpointSave operator
This operator will serialize and write a list of input LoDTensor variables
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录