Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41701969
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
41701969
编写于
6月 13, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[wip] ckpt m2 develop
上级
431491a2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
81 addition
and
10 deletion
+81
-10
paddle/fluid/operators/detail/request_handler.h
paddle/fluid/operators/detail/request_handler.h
+1
-0
paddle/fluid/operators/detail/request_handler_impl.h
paddle/fluid/operators/detail/request_handler_impl.h
+10
-0
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+8
-0
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+2
-0
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+40
-10
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+20
-0
未找到文件。
paddle/fluid/operators/detail/request_handler.h
浏览文件 @
41701969
...
...
@@ -36,6 +36,7 @@ namespace detail {
constexpr
char
kRequestSend
[]
=
"RequestSend"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
...
...
paddle/fluid/operators/detail/request_handler_impl.h
浏览文件 @
41701969
...
...
@@ -66,6 +66,16 @@ class RequestPrefetchHandler final : public RequestHandler {
const
std
::
string
&
out_var_name
=
""
)
override
;
};
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
public:
explicit
RequestCheckpointHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestCheckpointHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
=
""
)
override
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
41701969
...
...
@@ -253,11 +253,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_
.
reset
(
new
detail
::
RequestGetHandler
(
sync_mode
));
request_prefetch_handler_
.
reset
(
new
detail
::
RequestPrefetchHandler
(
sync_mode
));
request_checkpoint_handler_
.
reset
(
new
detail
::
RequestCheckpointHandler
(
sync_mode
));
rpc_service_
->
RegisterRPC
(
detail
::
kRequestSend
,
request_send_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestGet
,
request_get_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
request_prefetch_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestCheckpoint
,
request_checkpoint_handler_
.
get
());
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
program
=
optimize_block
->
Program
();
...
...
@@ -300,6 +304,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
f
(
request_checkpoint_handler_
.
get
());
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
...
@@ -344,6 +349,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
.
SetDefault
(
1
);
AddAttr
<
int
>
(
kCheckpointBlockId
,
"BolckID to run save checkpoint on pserer."
)
.
SetDefault
(
-
1
);
}
};
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
41701969
...
...
@@ -32,6 +32,7 @@ namespace operators {
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kPrefetchVarNameToBlockId
[]
=
"prefetch_var_name_to_block_id"
;
constexpr
char
kCheckpointBlockId
[]
=
"checkpint_block_id"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
);
...
...
@@ -66,6 +67,7 @@ class ListenAndServOp : public framework::OperatorBase {
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_checkpoint_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
};
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
41701969
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
...
...
@@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase {
MkDirRecursively
(
DirName
(
filename
).
c_str
());
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
auto
iname
=
Input
(
"X"
);
auto
*
var
=
scope
.
FindVar
(
iname
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s for save_op"
,
iname
);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
"SaveOp only support LoDTensor, %s has wrong type"
,
iname
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
SaveLodTensor
(
filename
,
place
,
var
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
SaveSelectedRows
(
filename
,
place
,
var
);
}
else
{
PADDLE_ENFORCE
(
false
,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type"
,
iname
);
}
}
SaveLodTensor
(
const
string
&
filename
,
const
platform
::
Place
&
place
,
Variable
*
var
)
{
auto
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
.
type
());
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
...
@@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase {
}
else
{
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
}
fout
.
close
()
}
SaveSelectedRows
(
const
string
&
filename
,
const
platform
::
Place
&
place
,
Variable
*
var
)
{
auto
&
selectedRows
=
var
->
Get
<
framework
::
SelectedRows
>
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
framework
::
SerializeToStream
(
fout
,
selectedRows
,
dev_ctx
);
fout
.
close
()
}
};
class
SaveOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor ) Input
tensor
to be saved"
);
AddInput
(
"X"
,
"(Tensor ) Input
LoDTensor and SelectedRows
to be saved"
);
AddComment
(
R"DOC(
Save operator
This operator will serialize and write a tensor variable to file on disk.
This operator will serialize and write a tensor
/selected rows
variable to file on disk.
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
41701969
...
...
@@ -522,6 +522,8 @@ class DistributeTranspiler:
pserver_index
,
pserver_program
,
pre_block_idx
,
grad_to_block_id
)
prefetch_var_name_to_block_id
=
self
.
_create_prefetch_block
(
pserver_index
,
pserver_program
,
table_opt_block
)
checkpoint_block_id
=
self
.
_create_checkpoint_save_block
(
pserver_program
,
table_opt_block
.
idx
)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
...
...
@@ -540,6 +542,7 @@ class DistributeTranspiler:
if
len
(
prefetch_var_name_to_block_id
)
>
0
:
attrs
[
'prefetch_var_name_to_block_id'
]
\
=
prefetch_var_name_to_block_id
attrs
[
'checkpint_block_id'
]
=
checkpoint_block_id
# step5 append the listen_and_serv op
pserver_program
.
global_block
().
append_op
(
...
...
@@ -824,6 +827,23 @@ class DistributeTranspiler:
return
table_opt_block
def
_create_checkpoint_save_block
(
self
,
pserver_program
,
pre_block_idx
):
"""
create a new block to handle save checkpoint.
"""
import
os
checkpoint_save_block
=
pserver_program
.
create_block
(
pre_block_idx
)
checkpoint_save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
self
.
table_name
]},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
"/tmp/pserver_ckpt/"
,
self
.
table_name
)
})
return
checkpoint_save_block
.
idx
def
_create_vars_from_blocklist
(
self
,
program
,
block_list
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录