Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e684575f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e684575f
编写于
6月 22, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
checkpoint feature optimized
上级
2229db52
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
40 addition
and
31 deletion
+40
-31
paddle/fluid/operators/checkpoint_notify_op.cc
paddle/fluid/operators/checkpoint_notify_op.cc
+7
-6
paddle/fluid/operators/detail/macros.h
paddle/fluid/operators/detail/macros.h
+4
-0
paddle/fluid/operators/distributed/grpc_server.cc
paddle/fluid/operators/distributed/grpc_server.cc
+4
-7
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+3
-2
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+6
-6
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+4
-2
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+5
-5
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+2
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+5
-2
未找到文件。
paddle/fluid/operators/checkpoint_notify_op.cc
浏览文件 @
e684575f
...
@@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase {
...
@@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase {
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"checkpoint notify sending "
<<
dir
<<
" to "
<<
epmap
[
i
];
auto
lookup_table_save_dir
=
auto
serial_looku_table
=
string
::
Sprintf
(
"%s/%s_%d"
,
dir
,
lookup_table_name
,
i
);
string
::
Sprintf
(
"%s/%s_%d"
,
dir
,
lookup_table_name
,
i
);
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
serial_looku_table
);
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
lookup_table_save_dir
);
VLOG
(
3
)
<<
"checkpoint notify sending lookup table: "
<<
lookup_table_name
<<
" and dir:"
<<
dir
<<
" to "
<<
epmap
[
i
];
}
}
rpc_client
->
Wait
();
rpc_client
->
Wait
();
}
}
...
@@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
string
>
(
"lookup_table"
,
AddAttr
<
std
::
string
>
(
"lookup_table"
,
"(string, default '') the lookup table name"
);
"(string, default '') the lookup table name"
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Prefetch
operator
CheckpointNotify
operator
This operator will send
Ids variables
to listen_and_serve op at
This operator will send
lookup table and it's checkpoint direcoty
to listen_and_serve op at
the parameter server
and fetch result back
.
the parameter server.
)DOC"
);
)DOC"
);
}
}
};
};
...
...
paddle/fluid/operators/detail/macros.h
浏览文件 @
e684575f
...
@@ -25,3 +25,7 @@
...
@@ -25,3 +25,7 @@
#define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T distributed::BRPCClient
#define RPCCLIENT_T distributed::BRPCClient
#endif
#endif
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr
char
LOOKUP_TABLE_PATH
[]
=
"lookup_table_path"
;
paddle/fluid/operators/distributed/grpc_server.cc
浏览文件 @
e684575f
...
@@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase {
...
@@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase {
RequestHandler
*
request_handler
,
int
req_id
)
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
()
,
true
));
request_handler
->
dev_ctx
()));
int
method_id
=
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kCheckpointNotify
);
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kCheckpointNotify
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
...
@@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase {
...
@@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase {
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_dir
=
request_
->
OutVarname
();
std
::
string
checkpoint_dir
=
request_
->
OutVarname
();
framework
::
Variable
*
invar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
VLOG
(
4
)
<<
"RequestCheckpointNotify notify: "
<<
checkpoint_notify
VLOG
(
4
)
<<
"RequestCheckpointNotify notify: "
<<
checkpoint_notify
<<
", dir: "
<<
checkpoint_dir
;
<<
", dir: "
<<
checkpoint_dir
;
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
invar
,
&
outva
r
,
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
nullptr
,
nullpt
r
,
checkpoint_dir
);
checkpoint_dir
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
}
}
...
@@ -320,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
...
@@ -320,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return
;
return
;
}
}
LOG
(
INFO
)
<<
"TryToRegisterNewOne on RPC NAME: "
<<
rpc_name
VLOG
(
4
)
<<
"TryToRegisterNewOne on RPC NAME: "
<<
rpc_name
<<
" REQ ID: "
<<
req_id
;
<<
" REQ ID: "
<<
req_id
;
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
e684575f
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
...
@@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
...
@@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id
!=
-
1
,
checkpoint_notify_id
!=
-
1
,
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
auto
*
lt_var
=
scope
->
FindVar
(
"loopup_table_path"
)
->
GetMutable
<
std
::
string
>
();
auto
*
lt_var
=
scope
->
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
lt_var
->
clear
();
lt_var
->
clear
();
lt_var
->
append
(
out_var_name
);
lt_var
->
append
(
out_var_name
);
VLOG
(
4
)
<<
"RequestCheckpointHandler update
loop
up_table_path to: "
VLOG
(
4
)
<<
"RequestCheckpointHandler update
var look
up_table_path to: "
<<
out_var_name
;
<<
out_var_name
;
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
return
true
;
return
true
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
e684575f
...
@@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
int
checkpoint_notify_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
checkpoint_notify_
block_
id
=
Attr
<
int
>
(
kCheckpointBlockId
);
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", end_point:"
<<
endpoint
<<
", CheckpointNotify Id: "
<<
checkpoint_notify_id
;
<<
", CheckpointNotify Id: "
<<
checkpoint_notify_
block_
id
;
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
...
@@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_
.
reset
(
request_prefetch_handler_
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
sync_mode
));
new
distributed
::
RequestPrefetchHandler
(
sync_mode
));
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
sync_mode
,
checkpoint_notify_id
));
sync_mode
,
checkpoint_notify_
block_
id
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
());
request_send_handler_
.
get
());
...
@@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework
::
Executor
executor
(
dev_place
);
framework
::
Executor
executor
(
dev_place
);
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
nullptr
;
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
nullptr
;
if
(
checkpoint_notify_id
!=
-
1
)
{
if
(
checkpoint_notify_
block_
id
!=
-
1
)
{
auto
ctx
=
executor
.
Prepare
(
*
program
,
checkpoint_notify_id
);
auto
ctx
=
executor
.
Prepare
(
*
program
,
checkpoint_notify_
block_
id
);
ckpt_pre_context
=
std
::
move
(
ctx
);
ckpt_pre_context
=
std
::
move
(
ctx
);
}
}
...
@@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
SavePort
();
SavePort
();
if
(
sync_mode
)
{
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block_id_list
,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block_id_list
,
checkpoint_notify_id
);
checkpoint_notify_
block_
id
);
}
else
{
}
else
{
RunAsyncLoop
(
&
executor
,
program
);
RunAsyncLoop
(
&
executor
,
program
);
}
}
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
e684575f
...
@@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
...
@@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddOutput
(
"Out"
,
"The
tensor
need to be loaded"
);
AddOutput
(
"Out"
,
"The
LoDTensor / SelectedRows
need to be loaded"
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"load_as_fp16"
,
"load_as_fp16"
,
"If true, the tensor will be first loaded and then "
"If true, the tensor will be first loaded and then "
...
@@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")"
)
R"(Variable will be loaded from "file_path")"
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddComment
(
"Load operator will load a tensor variable from disk file."
);
AddComment
(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file."
);
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
e684575f
...
@@ -24,6 +24,7 @@ limitations under the License. */
...
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase {
...
@@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase {
void
SaveSelectedRows
(
const
framework
::
Scope
&
scope
,
void
SaveSelectedRows
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
framework
::
Variable
*
var
)
const
{
auto
*
lt_var
=
auto
*
lt_var
=
scope
.
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
scope
.
FindVar
(
"loopup_table_path"
)
->
GetMutable
<
std
::
string
>
();
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
lt_var
!=
nullptr
,
lt_var
!=
nullptr
,
"Can not find variable loo
p
up_table_path for SaveSelectedRows"
);
"Can not find variable loo
k
up_table_path for SaveSelectedRows"
);
std
::
string
filename
=
lt_var
->
data
();
std
::
string
filename
=
lt_var
->
data
();
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
...
@@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Save operator
Save operator
This operator will serialize and write
a tensor/selected r
ows variable to file on disk.
This operator will serialize and write
LoDTensor / SelectedR
ows variable to file on disk.
)DOC"
);
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
"(boolean, default true)"
...
@@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
...
@@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"loopup_table_path"
).
front
();
auto
out_var_name
=
op_desc
.
Output
(
LOOKUP_TABLE_PATH
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
out_var
.
SetType
(
var_type
);
...
...
python/paddle/fluid/io.py
浏览文件 @
e684575f
...
@@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
...
@@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
main_program(Program): Find the variable named table_name in main_program
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
table_name(str): lookup table name
Returns:
Returns:
None
None
...
@@ -1188,7 +1189,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
...
@@ -1188,7 +1189,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
"""
trainer will load some args from it's
independent directory,
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
such as epoch_id and step_id.
Args:
Args:
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
e684575f
...
@@ -914,7 +914,7 @@ class DistributeTranspiler(object):
...
@@ -914,7 +914,7 @@ class DistributeTranspiler(object):
import
os
import
os
pserver_program
.
global_block
().
create_var
(
pserver_program
.
global_block
().
create_var
(
name
=
"loo
p
up_table_path"
,
name
=
"loo
k
up_table_path"
,
persistable
=
True
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
type
=
core
.
VarDesc
.
VarType
.
RAW
)
...
@@ -923,7 +923,10 @@ class DistributeTranspiler(object):
...
@@ -923,7 +923,10 @@ class DistributeTranspiler(object):
type
=
'save'
,
type
=
'save'
,
inputs
=
{
'X'
:
[
self
.
table_name
]},
inputs
=
{
'X'
:
[
self
.
table_name
]},
outputs
=
{},
outputs
=
{},
attrs
=
{
'file_path'
:
self
.
table_name
})
attrs
=
{
'file_path'
:
"this 'file_path' do not be used in save lookup table variable"
})
return
checkpoint_save_block
.
idx
return
checkpoint_save_block
.
idx
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录