Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bed0ecf3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bed0ecf3
编写于
3月 20, 2019
作者:
L
lujun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
checkpoint pr be moved here, test=develop
上级
5bb04ea4
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
691 addition
and
226 deletion
+691
-226
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_combine_op.cc
+83
-63
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+57
-37
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+58
-51
paddle/fluid/operators/save_load_combine_op_test.cc
paddle/fluid/operators/save_load_combine_op_test.cc
+2
-2
paddle/fluid/operators/save_load_op_test.cc
paddle/fluid/operators/save_load_op_test.cc
+2
-2
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+97
-67
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+3
-4
python/paddle/fluid/imperative/__init__.py
python/paddle/fluid/imperative/__init__.py
+4
-0
python/paddle/fluid/imperative/checkpoint.py
python/paddle/fluid/imperative/checkpoint.py
+194
-0
python/paddle/fluid/imperative/layers.py
python/paddle/fluid/imperative/layers.py
+28
-0
python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py
...addle/fluid/tests/unittests/test_imperative_checkpoint.py
+163
-0
未找到文件。
paddle/fluid/operators/load_combine_op.cc
浏览文件 @
bed0ecf3
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -19,21 +20,71 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
class
LoadCombineOp
:
public
framework
::
Operator
Base
{
class
LoadCombineOp
:
public
framework
::
Operator
WithKernel
{
public:
LoadCombineOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
model_from_memory
=
Attr
<
bool
>
(
"model_from_memory"
);
auto
out_var_names
=
Outputs
(
"Out"
);
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
proto
::
VarType
::
FP32
,
platform
::
CPUPlace
());
return
kt
;
}
};
class
LoadCombineOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"Out"
,
"(vector) The output LoDTensors that will be read from the input file."
)
.
AsDuplicable
();
AddAttr
<
bool
>
(
"load_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be first loaded and then "
"converted to float16 data type. Otherwise, the tensor will be "
"directly loaded without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string) "
"LoDTensors will be loaded from
\"
file_path
\"
."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
bool
>
(
"model_from_memory"
,
"(boolean, default false)"
"If true, file_path is in memory, and LoDTensors will be "
"loaded directly from memory"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
LoadCombine Operator.
LoadCombine operator loads LoDTensor variables from a file, which could be
loaded in memory already. The file should contain one or more LoDTensors
serialized using the SaveCombine operator. The
LoadCombine operator applies a deserialization strategy to appropriately load
the LodTensors, and this strategy complements the serialization strategy used
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
with the SaveCombine operator, and can only deserialize one or more LoDTensors
that were saved using the SaveCombine operator.
)DOC"
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
LoadCombineOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
place
=
ctx
.
GetPlace
();
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"file_path"
);
auto
load_as_fp16
=
ctx
.
Attr
<
bool
>
(
"load_as_fp16"
);
auto
model_from_memory
=
ctx
.
Attr
<
bool
>
(
"model_from_memory"
);
auto
&
out_var_names
=
ctx
.
Outputs
(
"Out"
);
PADDLE_ENFORCE_GT
(
static_cast
<
int
>
(
out_var_names
.
size
()),
0
,
"The number of output variables should be greater than 0."
);
...
...
@@ -41,27 +92,27 @@ class LoadCombineOp : public framework::OperatorBase {
std
::
ifstream
fin
(
filename
,
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load_combine op"
,
filename
);
LoadParamsFromBuffer
(
scope
,
place
,
&
fin
,
load_as_fp16
,
out_var_names
);
LoadParamsFromBuffer
(
ctx
,
place
,
&
fin
,
load_as_fp16
,
out_var_names
);
}
else
{
PADDLE_ENFORCE
(
!
filename
.
empty
(),
"Cannot load file from memory"
);
std
::
stringstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
LoadParamsFromBuffer
(
scope
,
place
,
&
fin
,
load_as_fp16
,
out_var_names
);
LoadParamsFromBuffer
(
ctx
,
place
,
&
fin
,
load_as_fp16
,
out_var_names
);
}
}
void
LoadParamsFromBuffer
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
framework
::
ExecutionContext
&
context
,
const
platform
::
Place
&
place
,
std
::
istream
*
buffer
,
bool
load_as_fp16
,
const
std
::
vector
<
std
::
string
>
&
out_var_names
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
out_vars
=
context
.
MultiOutputVar
(
"Out"
);
for
(
size_t
i
=
0
;
i
<
out_var_names
.
size
();
i
++
)
{
auto
*
out_var
=
scope
.
FindVar
(
out_var_names
[
i
]);
PADDLE_ENFORCE
(
out_vars
[
i
]
!=
nullptr
,
"Output variable %s cannot be found"
,
out_var_names
[
i
]);
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found"
,
out_var_names
[
i
]);
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
out_vars
[
i
]
->
GetMutable
<
framework
::
LoDTensor
>
();
// Error checking
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
*
buffer
),
"Cannot read more"
);
...
...
@@ -84,8 +135,8 @@ class LoadCombineOp : public framework::OperatorBase {
&
fp16_tensor
);
// reset output tensor
out_var
->
Clear
();
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_var
s
[
i
]
->
Clear
();
tensor
=
out_var
s
[
i
]
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
set_lod
(
fp16_tensor
.
lod
());
tensor
->
ShareDataWith
(
fp16_tensor
);
}
...
...
@@ -97,48 +148,17 @@ class LoadCombineOp : public framework::OperatorBase {
}
};
class
LoadCombineOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"Out"
,
"(vector) The output LoDTensors that will be read from the input file."
)
.
AsDuplicable
();
AddAttr
<
bool
>
(
"load_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be first loaded and then "
"converted to float16 data type. Otherwise, the tensor will be "
"directly loaded without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string) "
"LoDTensors will be loaded from
\"
file_path
\"
."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
bool
>
(
"model_from_memory"
,
"(boolean, default false)"
"If true, file_path is in memory, and LoDTensors will be "
"loaded directly from memory"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
LoadCombine Operator.
LoadCombine operator loads LoDTensor variables from a file, which could be
loaded in memory already. The file should contain one or more LoDTensors
serialized using the SaveCombine operator. The
LoadCombine operator applies a deserialization strategy to appropriately load
the LodTensors, and this strategy complements the serialization strategy used
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
with the SaveCombine operator, and can only deserialize one or more LoDTensors
that were saved using the SaveCombine operator.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
load_combine
,
ops
::
LoadCombineOp
,
ops
::
LoadCombineOpProtoMaker
);
REGISTER_OP_CPU_KERNEL
(
load_combine
,
ops
::
LoadCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LoadCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
LoadCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
LoadCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/load_op.cc
浏览文件 @
bed0ecf3
...
...
@@ -21,31 +21,63 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
class
LoadOp
:
public
framework
::
Operator
Base
{
class
LoadOp
:
public
framework
::
Operator
WithKernel
{
public:
LoadOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
proto
::
VarType
::
FP32
,
platform
::
CPUPlace
());
return
kt
;
}
};
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"Out"
,
"The LoDTensor / SelectedRows need to be loaded"
);
AddAttr
<
bool
>
(
"load_as_fp16"
,
"If true, the tensor will be first loaded and then "
"converted to float16 data type. Otherwise, the tensor will be "
"directly loaded without data type conversion. Default is false."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
R"(Variable will be loaded from "file_path")"
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddComment
(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file."
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
LoadOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
place
=
ctx
.
GetPlace
();
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
,
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
filename
);
auto
out_var_name
=
Output
(
"Out"
);
auto
*
out_var
=
scope
.
FindVar
(
out_var_name
);
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found in scope %p"
,
out_var_name
,
&
scope
);
auto
out_var_name
=
ctx
.
Outputs
(
"Out"
).
data
();
auto
*
out_var
=
ctx
.
OutputVar
(
"Out"
);
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found "
,
out_var_name
);
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable cannot be found "
);
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
LoadLodTensor
(
fin
,
place
,
out_var
);
LoadLodTensor
(
fin
,
place
,
out_var
,
ctx
);
}
else
if
(
out_var
->
IsType
<
framework
::
SelectedRows
>
())
{
LoadSelectedRows
(
fin
,
place
,
out_var
);
}
else
{
...
...
@@ -57,14 +89,15 @@ class LoadOp : public framework::OperatorBase {
}
void
LoadLodTensor
(
std
::
istream
&
fin
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
framework
::
Variable
*
var
,
const
framework
::
ExecutionContext
&
ctx
)
const
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
DeserializeFromStream
(
fin
,
tensor
,
dev_ctx
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
load_as_fp16
=
ctx
.
Attr
<
bool
>
(
"load_as_fp16"
);
auto
in_dtype
=
tensor
->
type
();
auto
out_dtype
=
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
...
@@ -97,27 +130,14 @@ class LoadOp : public framework::OperatorBase {
}
};
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"Out"
,
"The LoDTensor / SelectedRows need to be loaded"
);
AddAttr
<
bool
>
(
"load_as_fp16"
,
"If true, the tensor will be first loaded and then "
"converted to float16 data type. Otherwise, the tensor will be "
"directly loaded without data type conversion. Default is false."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
R"(Variable will be loaded from "file_path")"
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddComment
(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file."
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
load
,
ops
::
LoadOp
,
ops
::
LoadOpProtoMaker
);
REGISTER_OP_CPU_KERNEL
(
load
,
ops
::
LoadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LoadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
LoadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
LoadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
bed0ecf3
...
...
@@ -27,20 +27,53 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
class
SaveCombineOp
:
public
framework
::
Operator
Base
{
class
SaveCombineOp
:
public
framework
::
Operator
WithKernel
{
public:
SaveCombineOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
class
SaveCombineOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
AddComment
(
R"DOC(
SaveCombine operator
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
"Overwrite the output file if it exists."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"save_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be converted to float16 data "
"type and then saved. Otherwise, the tensor will be "
"directly saved without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string)"
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SaveCombineOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
place
=
ctx
.
GetPlace
();
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
ctx
.
Attr
<
bool
>
(
"overwrite"
);
auto
save_as_fp16
=
ctx
.
Attr
<
bool
>
(
"save_as_fp16"
);
bool
is_present
=
FileExists
(
filename
);
if
(
is_present
&&
!
overwrite
)
{
...
...
@@ -53,7 +86,8 @@ class SaveCombineOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
auto
inp_var_names
=
Inputs
(
"X"
);
auto
&
inp_var_names
=
ctx
.
Inputs
(
"X"
);
auto
&
inp_vars
=
ctx
.
MultiInputVar
(
"X"
);
PADDLE_ENFORCE_GT
(
static_cast
<
int
>
(
inp_var_names
.
size
()),
0
,
"The number of input variables should be greater than 0"
);
...
...
@@ -62,16 +96,14 @@ class SaveCombineOp : public framework::OperatorBase {
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
inp_var_names
.
size
();
i
++
)
{
auto
*
var
=
scope
.
FindVar
(
inp_var_names
[
i
]);
PADDLE_ENFORCE
(
var
!=
nullptr
,
PADDLE_ENFORCE
(
inp_vars
[
i
]
!=
nullptr
,
"Cannot find variable %s for save_combine_op"
,
inp_var_names
[
i
]);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
PADDLE_ENFORCE
(
inp_vars
[
i
]
->
IsType
<
framework
::
LoDTensor
>
(),
"SaveCombineOp only supports LoDTensor, %s has wrong type"
,
inp_var_names
[
i
]);
auto
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
tensor
=
inp_vars
[
i
]
->
Get
<
framework
::
LoDTensor
>
();
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
...
...
@@ -95,38 +127,6 @@ class SaveCombineOp : public framework::OperatorBase {
}
};
class
SaveCombineOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(vector) Input LoDTensors that need to be saved together in a file."
)
.
AsDuplicable
();
AddComment
(
R"DOC(
SaveCombine operator
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
"Overwrite the output file if it exists."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"save_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be converted to float16 data "
"type and then saved. Otherwise, the tensor will be "
"directly saved without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string)"
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -134,3 +134,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
save_combine
,
ops
::
SaveCombineOp
,
ops
::
SaveCombineOpProtoMaker
);
REGISTER_OP_CPU_KERNEL
(
save_combine
,
ops
::
SaveCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SaveCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
SaveCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
SaveCombineOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/save_load_combine_op_test.cc
浏览文件 @
bed0ecf3
...
...
@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
USE_
NO_KERNEL
_OP
(
save_combine
);
USE_
NO_KERNEL
_OP
(
load_combine
);
USE_
CPU_ONLY
_OP
(
save_combine
);
USE_
CPU_ONLY
_OP
(
load_combine
);
template
<
typename
T
,
typename
U
>
T
*
CreateForSaveCombineOp
(
int
x
,
int
y
,
const
std
::
vector
<
int
>&
lod_info
,
...
...
paddle/fluid/operators/save_load_op_test.cc
浏览文件 @
bed0ecf3
...
...
@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
USE_
NO_KERNEL
_OP
(
save
);
USE_
NO_KERNEL
_OP
(
load
);
USE_
CPU_ONLY
_OP
(
save
);
USE_
CPU_ONLY
_OP
(
load
);
TEST
(
SaveLoadOp
,
CPU
)
{
paddle
::
framework
::
Scope
scope
;
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
bed0ecf3
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
...
...
@@ -29,29 +30,88 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr
char
LOOKUP_TABLE_PATH
[]
=
"kLookupTablePath"
;
class
SaveOp
:
public
framework
::
Operator
Base
{
class
SaveOp
:
public
framework
::
Operator
WithKernel
{
public:
SaveOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
iname
=
Input
(
"X"
);
auto
*
var
=
scope
.
FindVar
(
iname
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s for save_op"
,
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
ctx
.
GetPlace
());
}
};
class
SaveOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor ) Input LoDTensor and SelectedRows to be saved"
);
AddComment
(
R"DOC(
Save operator
This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
"Overwrite the output file if exist"
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"save_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be converted to float16 data "
"type and then saved. Otherwise, the tensor will be "
"directly saved without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string)"
"The
\"
file_path
\"
where the variable will be saved."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddOutput
(
LOOKUP_TABLE_PATH
,
"(string)"
"for pserver: The
\"
kLookupTablePath
\"
where checkpoint notify "
"to save lookup table variables"
" to directory specified."
)
.
AsDispensable
();
}
};
class
SaveOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
ctx
->
SetType
(
LOOKUP_TABLE_PATH
,
var_type
);
}
};
class
SaveOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SaveOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
place
=
ctx
.
GetPlace
();
auto
*
input_var
=
ctx
.
InputVar
(
"X"
);
auto
iname
=
ctx
.
Inputs
(
"X"
).
data
();
PADDLE_ENFORCE
(
input_var
!=
nullptr
,
"Cannot find variable %s for save_op"
,
iname
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
SaveLodTensor
(
place
,
var
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
SaveSelectedRows
(
scope
,
place
,
var
);
if
(
input_
var
->
IsType
<
framework
::
LoDTensor
>
())
{
SaveLodTensor
(
ctx
,
place
,
input_
var
);
}
else
if
(
input_
var
->
IsType
<
framework
::
SelectedRows
>
())
{
SaveSelectedRows
(
ctx
,
place
,
input_
var
);
}
else
{
PADDLE_ENFORCE
(
false
,
...
...
@@ -60,10 +120,11 @@ class SaveOp : public framework::OperatorBase {
}
}
void
SaveLodTensor
(
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
void
SaveLodTensor
(
const
framework
::
ExecutionContext
&
ctx
,
const
platform
::
Place
&
place
,
const
framework
::
Variable
*
var
)
const
{
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
ctx
.
Attr
<
bool
>
(
"overwrite"
);
if
(
FileExists
(
filename
)
&&
!
overwrite
)
{
PADDLE_THROW
(
"%s is existed, cannot save to it when overwrite=false"
,
...
...
@@ -84,7 +145,7 @@ class SaveOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
save_as_fp16
=
ctx
.
Attr
<
bool
>
(
"save_as_fp16"
);
auto
in_dtype
=
tensor
.
type
();
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
...
@@ -102,13 +163,15 @@ class SaveOp : public framework::OperatorBase {
fout
.
close
();
}
void
SaveSelectedRows
(
const
framework
::
Scope
&
scope
,
void
SaveSelectedRows
(
const
framework
::
ExecutionContext
&
ctx
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
auto
*
lt_var
=
scope
.
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
(
);
const
framework
::
Variable
*
var
)
const
{
framework
::
Variable
*
out_put_var
=
ctx
.
OutputVar
(
LOOKUP_TABLE_PATH
);
PADDLE_ENFORCE
(
l
t_var
!=
nullptr
,
out_pu
t_var
!=
nullptr
,
"Can not find variable kLookupTablePath for SaveSelectedRows"
);
auto
*
lt_var
=
out_put_var
->
GetMutable
<
std
::
string
>
();
std
::
string
filename
=
lt_var
->
data
();
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
...
...
@@ -130,50 +193,17 @@ class SaveOp : public framework::OperatorBase {
}
};
class
SaveOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor ) Input LoDTensor and SelectedRows to be saved"
);
AddComment
(
R"DOC(
Save operator
This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
"Overwrite the output file if exist"
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"save_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be converted to float16 data "
"type and then saved. Otherwise, the tensor will be "
"directly saved without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string)"
"The
\"
file_path
\"
where the variable will be saved."
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
}
};
class
SaveOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
LOOKUP_TABLE_PATH
).
front
();
ctx
->
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
}
};
class
SaveOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
save
,
ops
::
SaveOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SaveOpProtoMaker
,
ops
::
SaveOpVarTypeInference
,
ops
::
SaveOpShapeInference
);
REGISTER_OPERATOR
(
save
,
ops
::
SaveOp
,
ops
::
SaveOpProtoMaker
,
ops
::
SaveOpVarTypeInference
,
ops
::
SaveOpShapeInference
);
REGISTER_OP_CPU_KERNEL
(
save
,
ops
::
SaveOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SaveOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
SaveOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
SaveOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int8_t
>
,
ops
::
SaveOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
python/paddle/fluid/framework.py
浏览文件 @
bed0ecf3
...
...
@@ -644,10 +644,9 @@ class Operator(object):
outputs={"Out": [var1]})
"""
OP_WITHOUT_KERNEL_SET
=
{
'feed'
,
'fetch'
,
'save'
,
'load'
,
'recurrent'
,
'go'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'listen_and_serv'
,
'save_combine'
,
'load_combine'
,
'ncclInit'
,
'select'
,
'checkpoint_notify'
,
'gen_nccl_id'
'feed'
,
'fetch'
,
'recurrent'
,
'go'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'listen_and_serv'
,
'ncclInit'
,
'select'
,
'checkpoint_notify'
,
'gen_nccl_id'
}
def
__init__
(
self
,
...
...
python/paddle/fluid/imperative/__init__.py
浏览文件 @
bed0ecf3
...
...
@@ -29,9 +29,13 @@ from .tracer import *
from
.
import
profiler
from
.profiler
import
*
from
.
import
checkpoint
from
.checkpoint
import
*
__all__
=
[]
__all__
+=
layers
.
__all__
__all__
+=
base
.
__all__
__all__
+=
nn
.
__all__
__all__
+=
tracer
.
__all__
__all__
+=
profiler
.
__all__
__all__
+=
checkpoint
.
__all__
python/paddle/fluid/imperative/checkpoint.py
0 → 100644
浏览文件 @
bed0ecf3
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
collections
from
..
import
core
from
..framework
import
Variable
,
Parameter
,
default_main_program
from
.layers
import
Layer
__all__
=
[
'save_persistables'
,
'load_persistables'
]
def
save_persistables
(
obj
,
dirname
,
filename
=
None
):
"""
This function filters out all variables in layer.parameters from the
give `layer` and then trys to load these variables from the folder
`dirname` or the file `filename`.
Use the `dirname` to specify the folder where persistable variables were
saved. If variables were saved in separate files, set `filename` None;
if all variables were saved in a single file, use `filename` to specify
the file name.
Args:
var_list(dict of Parameters|Layer): The parameters will
be saved. If it is None, nothing
will be deal.
dirname(str): The directory path.
filename(str|None): The file which saved all variables. If variables were
saved in differnet files, set it to None.
Default: None
Returns:
Examples:
.. code-block:: python
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
x = to_variable(x_data)
y = to_variable(y_data)
init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data)
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell)
param_path = "./my_paddle_model"
fluid.imperative.checkpoint.save_persistables(ptb_model.parameters(), dirname=param_path,
layer=ptb_model)
"""
if
isinstance
(
obj
,
collections
.
OrderedDict
):
_save_var_to_file
(
obj
,
dirname
,
filename
)
elif
isinstance
(
obj
,
Layer
):
_save_var_to_file
(
obj
.
state_dict
(
include_sublayers
=
True
),
dirname
,
filename
)
def
load_persistables
(
obj
,
dirname
,
filename
=
None
):
"""
This function trys to load persistable variables from the folder
`dirname` or the file `filename`.
Use the `dirname` to specify the folder where persistable variables were
saved. If variables were saved in separate files, set `filename` None;
if all variables were saved in a single file, use `filename` to specify
the file name.
Args:
obj(dict of Parameters|Layer): The parameters will be loaded.
dirname(str): The directory path.
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
saved in differnet files, set it to None.
Default: None
Returns:
dict: The parameter-dict resumed from file
Examples:
.. code-block:: python
my_layer = layer(fluid.imperative.Layer)
param_path = "./my_paddle_model"
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.parameters(), param_path)
param_1 = param_dict['PtbModel_0.w_1']
or:
my_layer = layer(fluid.imperative.Layer)
param_path = "./my_paddle_model"
filename = "model.file"
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list, param_path,
filename=filename)
param_1 = param_dict['PtbModel_0.w_1']
"""
if
isinstance
(
obj
,
collections
.
OrderedDict
):
return
_load_var_from_file
(
obj
,
dirname
,
filename
)
elif
isinstance
(
obj
,
Layer
):
return
_load_var_from_file
(
obj
.
state_dict
(
include_sublayers
=
True
),
dirname
,
filename
)
return
{}
def
_save_var_to_file
(
stat_dict
,
file_dir
,
file_name
):
save_block
=
default_main_program
().
global_block
()
save_var_map
=
{}
for
each_var
in
stat_dict
.
items
():
save_var_map
[
each_var
.
name
]
=
each_var
if
file_name
is
None
:
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
each_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
file_dir
,
each_var
.
name
)})
if
file_name
is
not
None
:
save_var_list
=
[]
for
name
in
sorted
(
save_var_map
.
keys
()):
save_var_list
.
append
(
save_var_map
[
name
])
save_block
.
append_op
(
type
=
'save_combine'
,
inputs
=
{
'X'
:
save_var_list
},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
file_dir
,
file_name
)})
def
_load_var_from_file
(
stat_dict
,
file_dir
,
file_name
):
load_block
=
default_main_program
().
global_block
()
load_var_map
=
{}
for
each_var
in
stat_dict
.
items
():
assert
isinstance
(
each_var
,
Variable
)
if
each_var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
if
file_name
is
None
:
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
file_dir
,
each_var
.
name
)})
load_var_map
[
new_var
.
name
]
=
new_var
if
file_name
is
not
None
:
load_var_list
=
[]
for
name
in
sorted
(
load_var_map
.
keys
()):
load_var_list
.
append
(
load_var_map
[
name
])
load_block
.
append_op
(
type
=
'load_combine'
,
inputs
=
{},
outputs
=
{
"Out"
:
load_var_list
},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
file_dir
,
file_name
)})
for
res_var
in
load_var_list
:
load_var_map
[
res_var
.
name
]
=
res_var
return
load_var_map
def
_clone_var_in_block_
(
block
,
var
):
assert
isinstance
(
var
,
Variable
)
return
block
.
create_var
(
name
=
var
.
name
,
shape
=
var
.
shape
,
dtype
=
var
.
dtype
,
type
=
var
.
type
,
lod_level
=
var
.
lod_level
,
persistable
=
True
)
python/paddle/fluid/imperative/layers.py
浏览文件 @
bed0ecf3
...
...
@@ -212,6 +212,34 @@ class Layer(core.Layer):
else
:
object
.
__delattr__
(
self
,
name
)
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
include_sublayers
=
True
):
if
destination
is
None
:
destination
=
collections
.
OrderedDict
()
for
name
,
data
in
self
.
_parameters
.
items
():
if
data
is
not
None
:
destination
[
prefix
+
name
]
=
data
if
include_sublayers
:
for
layer_name
,
layer_item
in
self
.
_sub_layers
.
items
():
if
layer_item
is
not
None
:
destination_temp
=
destination
.
copy
()
destination_temp
.
update
(
layer_item
.
state_dict
(
destination_temp
,
prefix
+
layer_name
+
"."
,
include_sublayers
))
destination
=
destination_temp
return
destination
def
load_dict
(
self
,
stat_dict
,
include_sublayers
=
True
):
for
name
,
item
in
self
.
__dict__
.
get
(
'_parameters'
,
None
).
items
():
if
item
.
name
in
stat_dict
:
self
.
__setattr__
(
name
,
stat_dict
[
item
.
name
])
if
include_sublayers
:
for
layer_name
,
layer_item
in
self
.
_sub_layers
.
items
():
if
layer_item
is
not
None
:
layer_item
.
load_dict
(
stat_dict
)
class
PyLayer
(
core
.
PyLayer
):
"""Layers composed of user-defined python codes."""
...
...
python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py
0 → 100644
浏览文件 @
bed0ecf3
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.imperative.nn
import
Conv2D
,
Pool2D
,
FC
from
paddle.fluid.imperative.base
import
to_variable
class
SimpleImgConvPool
(
fluid
.
imperative
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
num_filters
,
filter_size
,
pool_size
,
pool_stride
,
pool_padding
=
0
,
pool_type
=
'max'
,
global_pooling
=
False
,
conv_stride
=
1
,
conv_padding
=
0
,
conv_dilation
=
1
,
conv_groups
=
1
,
act
=
None
,
use_cudnn
=
False
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
SimpleImgConvPool
,
self
).
__init__
(
name_scope
)
self
.
_conv2d
=
Conv2D
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
conv_stride
,
padding
=
conv_padding
,
dilation
=
conv_dilation
,
groups
=
conv_groups
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
use_cudnn
)
self
.
_pool2d
=
Pool2D
(
self
.
full_name
(),
pool_size
=
pool_size
,
pool_type
=
pool_type
,
pool_stride
=
pool_stride
,
pool_padding
=
pool_padding
,
global_pooling
=
global_pooling
,
use_cudnn
=
use_cudnn
)
def
forward
(
self
,
inputs
):
x
=
self
.
_conv2d
(
inputs
)
x
=
self
.
_pool2d
(
x
)
return
x
class
MNIST
(
fluid
.
imperative
.
Layer
):
def
__init__
(
self
,
name_scope
):
super
(
MNIST
,
self
).
__init__
(
name_scope
)
self
.
_simple_img_conv_pool_1
=
SimpleImgConvPool
(
self
.
full_name
(),
1
,
20
,
5
,
2
,
2
,
act
=
"relu"
)
self
.
_simple_img_conv_pool_2
=
SimpleImgConvPool
(
self
.
full_name
(),
20
,
50
,
5
,
2
,
2
,
act
=
"relu"
)
pool_2_shape
=
50
*
4
*
4
SIZE
=
10
scale
=
(
2.0
/
(
pool_2_shape
**
2
*
SIZE
))
**
0.5
self
.
_fc
=
FC
(
self
.
full_name
(),
10
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
def
forward
(
self
,
inputs
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
self
.
_fc
(
x
)
return
x
class
TestImperativeCheckpoint
(
unittest
.
TestCase
):
def
save_load_persistables
(
self
):
seed
=
90
epoch_num
=
1
with
fluid
.
imperative
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
mnist
=
MNIST
(
"mnist"
)
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
,
drop_last
=
True
)
dy_param_init_value
=
{}
step
=
0
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
dy_x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
128
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
_stop_gradient
=
True
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
dy_out
=
avg_loss
.
_numpy
()
avg_loss
.
_backward
()
sgd
.
minimize
(
avg_loss
)
fluid
.
imperative
.
save_persistables
(
mnist
,
"save_dir"
)
mnist
.
clear_gradients
()
for
param
in
mnist
.
parameters
():
dy_param_init_value
[
param
.
name
]
=
param
.
_numpy
()
mnist
.
load_dict
(
fluid
.
imperative
.
load_persistables
(
mnist
,
"save_dir"
))
restore
=
mnist
.
parameters
()
self
.
assertEqual
(
len
(
dy_param_init_value
),
len
(
restore
))
for
value
in
restore
:
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_init_value
[
value
.
name
]))
self
.
assertTrue
(
np
.
isfinite
(
value
.
all
()))
self
.
assertFalse
(
np
.
isnan
(
value
.
any
()))
step
+=
1
if
step
>
20
:
break
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录