Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
b20fa022
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b20fa022
编写于
6月 26, 2018
作者:
T
tangwei12
提交者:
GitHub
6月 26, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11490 from seiriosPlus/ckpt_m2
Checkpoint M2: lookup table checkpoint
上级
bb29800a
f57978e6
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
617 addition
and
86 deletion
+617
-86
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-2
paddle/fluid/operators/checkpoint_notify_op.cc
paddle/fluid/operators/checkpoint_notify_op.cc
+88
-0
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+17
-0
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+17
-0
paddle/fluid/operators/distributed/grpc_server.cc
paddle/fluid/operators/distributed/grpc_server.cc
+45
-3
paddle/fluid/operators/distributed/grpc_service.h
paddle/fluid/operators/distributed/grpc_service.h
+4
-1
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+11
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+23
-0
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+15
-0
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+4
-0
paddle/fluid/operators/distributed/send_recv.proto
paddle/fluid/operators/distributed/send_recv.proto
+2
-0
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+30
-8
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+5
-1
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+35
-6
paddle/fluid/operators/save_load_op_test.cc
paddle/fluid/operators/save_load_op_test.cc
+1
-0
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+78
-18
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+5
-2
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+182
-20
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+29
-25
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+24
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
b20fa022
...
@@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE)
...
@@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE)
endif
()
endif
()
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
foreach
(
dist_op
"prefetch_op"
"listen_and_serv_op"
"send_op"
"recv_op"
"send_barrier_op"
"fetch_barrier_op"
)
foreach
(
dist_op
"prefetch_op"
"
checkpoint_notify_op"
"
listen_and_serv_op"
"send_op"
"recv_op"
"send_barrier_op"
"fetch_barrier_op"
)
op_library
(
${
dist_op
}
DEPS
${
DISTRIBUTE_DEPS
}
)
op_library
(
${
dist_op
}
DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
${
dist_op
}
.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
${
dist_op
}
.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
endforeach
()
endforeach
()
...
@@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE)
...
@@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE)
set
(
DEPS_OPS
${
DEPS_OPS
}
gen_nccl_id_op
)
set
(
DEPS_OPS
${
DEPS_OPS
}
gen_nccl_id_op
)
endif
()
endif
()
else
()
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op
)
set
(
DEPS_OPS
${
DEPS_OPS
}
checkpoint_notify_op
prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op
)
endif
()
endif
()
op_library
(
cross_entropy_op DEPS cross_entropy
)
op_library
(
cross_entropy_op DEPS cross_entropy
)
...
...
paddle/fluid/operators/checkpoint_notify_op.cc
0 → 100644
浏览文件 @
b20fa022
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
operators
{
class
CheckpointNotifyOp
:
public
framework
::
OperatorBase
{
public:
CheckpointNotifyOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
std
::
string
lookup_table_name
=
Attr
<
std
::
string
>
(
"lookup_table"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
auto
lookup_table_save_dir
=
string
::
Sprintf
(
"%s/%s_%d"
,
dir
,
lookup_table_name
,
i
);
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
lookup_table_save_dir
);
VLOG
(
3
)
<<
"checkpoint notify sending lookup table: "
<<
lookup_table_name
<<
" and dir:"
<<
dir
<<
" to "
<<
epmap
[
i
];
}
rpc_client
->
Wait
();
}
};
class
CheckpointNotifyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order"
)
.
SetDefault
({
"127.0.0.1:6164"
});
AddAttr
<
std
::
string
>
(
"dir"
,
"(string, default '') indicate the folder checkpoint will use"
);
AddAttr
<
std
::
string
>
(
"lookup_table"
,
"(string, default '') the lookup table name"
);
AddComment
(
R"DOC(
CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server.
)DOC"
);
}
};
class
CheckpointNotifyOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
checkpoint_notify
,
ops
::
CheckpointNotifyOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CheckpointNotifyOpMaker
,
ops
::
CheckpointNotifyOpShapeInference
);
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
b20fa022
...
@@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
...
@@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_
++
;
req_count_
++
;
}
}
void
GRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
CheckpointNotifyProcessor
*
s
=
new
CheckpointNotifyProcessor
(
ch
);
s
->
Prepare
(
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
CHECKPOINT_SAVE_MESSAGE
);
req
.
set_out_varname
(
dir
);
auto
rpc
=
s
->
stub_
->
AsyncCheckpointNotify
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
}
void
GRPCClient
::
Wait
()
{
void
GRPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
...
...
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
b20fa022
...
@@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor {
...
@@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
};
class
CheckpointNotifyProcessor
:
public
BaseProcessor
{
public:
explicit
CheckpointNotifyProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
}
virtual
~
CheckpointNotifyProcessor
()
{}
virtual
void
Process
()
{}
sendrecv
::
VoidMessage
reply_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
class
GRPCClient
:
public
RPCClient
{
class
GRPCClient
:
public
RPCClient
{
public:
public:
GRPCClient
()
{}
GRPCClient
()
{}
...
@@ -197,6 +211,9 @@ class GRPCClient : public RPCClient {
...
@@ -197,6 +211,9 @@ class GRPCClient : public RPCClient {
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_grpc_deadline
)
override
;
void
Wait
()
override
;
void
Wait
()
override
;
void
SendComplete
()
override
;
void
SendComplete
()
override
;
...
...
paddle/fluid/operators/distributed/grpc_server.cc
浏览文件 @
b20fa022
...
@@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase {
...
@@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase {
framework
::
Scope
*
local_scope_
;
framework
::
Scope
*
local_scope_
;
};
};
class
RequestCheckpointNotify
final
:
public
RequestBase
{
public:
explicit
RequestCheckpointNotify
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kCheckpointNotify
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestCheckpointNotify
()
{}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
auto
scope
=
request_
->
GetMutableLocalScope
();
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_dir
=
request_
->
OutVarname
();
VLOG
(
4
)
<<
"RequestCheckpointNotify notify: "
<<
checkpoint_notify
<<
", dir: "
<<
checkpoint_dir
;
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
nullptr
,
nullptr
,
checkpoint_dir
);
Finish
(
reply_
,
&
responder_
);
}
protected:
std
::
shared_ptr
<
VariableResponse
>
request_
;
sendrecv
::
VoidMessage
reply_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
void
AsyncGRPCServer
::
WaitServerReady
()
{
void
AsyncGRPCServer
::
WaitServerReady
()
{
VLOG
(
4
)
<<
"AsyncGRPCServer is wait server ready"
;
VLOG
(
4
)
<<
"AsyncGRPCServer is wait server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
...
@@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() {
...
@@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() {
reqs
.
reserve
(
kRequestBufSize
);
reqs
.
reserve
(
kRequestBufSize
);
for
(
int
i
=
0
;
i
<
kRequestBufSize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kRequestBufSize
;
i
++
)
{
VLOG
(
6
)
<<
"TryToRegisterNewOne on RPC NAME: "
<<
rpc_name
<<
" I: "
<<
i
;
TryToRegisterNewOne
(
rpc_name
,
i
);
TryToRegisterNewOne
(
rpc_name
,
i
);
}
}
...
@@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
...
@@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return
;
return
;
}
}
VLOG
(
4
)
<<
"
register send rpc_name:
"
<<
rpc_name
VLOG
(
4
)
<<
"
TryToRegisterNewOne on RPC NAME:
"
<<
rpc_name
<<
"
, handler:"
<<
rpc_call_map_
[
kRequestSend
]
;
<<
"
REQ ID: "
<<
req_id
;
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
...
@@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
...
@@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b
=
new
RequestGet
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
b
=
new
RequestGet
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestPrefetch
)
{
}
else
if
(
rpc_name
==
kRequestPrefetch
)
{
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestCheckpoint
)
{
b
=
new
RequestCheckpointNotify
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
{
}
else
{
PADDLE_ENFORCE
(
false
,
"not supported rpc"
);
PADDLE_ENFORCE
(
false
,
"not supported rpc"
);
}
}
...
@@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest(
...
@@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest(
while
(
true
)
{
while
(
true
)
{
VLOG
(
4
)
<<
"HandleRequest "
<<
rpc_name
<<
" wait next"
;
VLOG
(
4
)
<<
"HandleRequest "
<<
rpc_name
<<
" wait next"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
"CompletionQueue "
<<
rpc_name
<<
" shutdown!"
;
VLOG
(
3
)
<<
"CompletionQueue "
<<
rpc_name
<<
" shutdown!"
;
break
;
break
;
}
}
...
...
paddle/fluid/operators/distributed/grpc_service.h
浏览文件 @
b20fa022
...
@@ -80,10 +80,11 @@ enum class GrpcMethod {
...
@@ -80,10 +80,11 @@ enum class GrpcMethod {
kSendVariable
,
kSendVariable
,
kGetVariable
,
kGetVariable
,
kPrefetchVariable
,
kPrefetchVariable
,
kCheckpointNotify
,
};
};
static
const
int
kGrpcNumMethods
=
static
const
int
kGrpcNumMethods
=
static_cast
<
int
>
(
GrpcMethod
::
k
PrefetchVariable
)
+
1
;
static_cast
<
int
>
(
GrpcMethod
::
k
CheckpointNotify
)
+
1
;
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
switch
(
id
)
{
switch
(
id
)
{
...
@@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
...
@@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return
"/sendrecv.SendRecvService/GetVariable"
;
return
"/sendrecv.SendRecvService/GetVariable"
;
case
GrpcMethod
::
kPrefetchVariable
:
case
GrpcMethod
::
kPrefetchVariable
:
return
"/sendrecv.SendRecvService/PrefetchVariable"
;
return
"/sendrecv.SendRecvService/PrefetchVariable"
;
case
GrpcMethod
::
kCheckpointNotify
:
return
"/sendrecv.SendRecvService/CheckpointNotify"
;
}
}
// Shouldn't be reached.
// Shouldn't be reached.
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
b20fa022
...
@@ -36,12 +36,16 @@ namespace distributed {
...
@@ -36,12 +36,16 @@ namespace distributed {
constexpr
char
kRequestSend
[]
=
"RequestSend"
;
constexpr
char
kRequestSend
[]
=
"RequestSend"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
class
RPCServer
;
class
RPCServer
;
class
RequestHandler
{
class
RequestHandler
{
...
@@ -69,6 +73,11 @@ class RequestHandler {
...
@@ -69,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_
=
g
;
prefetch_var_name_to_prepared_ctx_
=
g
;
}
}
void
SetCheckpointNotifyPreparedCtx
(
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
g
)
{
checkpoint_prepared_ctx_
=
g
;
}
// Used for async.
// Used for async.
void
SetGradToPreparedCtx
(
void
SetGradToPreparedCtx
(
std
::
unordered_map
<
std
::
unordered_map
<
...
@@ -115,6 +124,8 @@ class RequestHandler {
...
@@ -115,6 +124,8 @@ class RequestHandler {
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
prefetch_var_name_to_prepared_ctx_
;
prefetch_var_name_to_prepared_ctx_
;
// used for checkpoint notify
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
checkpoint_prepared_ctx_
;
// Used for async.
// Used for async.
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
b20fa022
...
@@ -22,11 +22,16 @@
...
@@ -22,11 +22,16 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
distributed
{
namespace
distributed
{
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr
char
LOOKUP_TABLE_PATH
[]
=
"kLookupTablePath"
;
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
...
@@ -119,6 +124,24 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
...
@@ -119,6 +124,24 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return
true
;
return
true
;
}
}
bool
RequestCheckpointHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
)
{
PADDLE_ENFORCE
(
checkpoint_notify_id
!=
-
1
,
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
auto
*
lt_var
=
scope
->
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
lt_var
->
clear
();
lt_var
->
append
(
out_var_name
);
VLOG
(
4
)
<<
"RequestCheckpointHandler update var kLookupTablePath to: "
<<
out_var_name
;
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
return
true
;
}
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
b20fa022
...
@@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler {
...
@@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const
std
::
string
&
out_var_name
=
""
)
override
;
const
std
::
string
&
out_var_name
=
""
)
override
;
};
};
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
public:
explicit
RequestCheckpointHandler
(
bool
sync_mode
,
int
checkpoint_notify_id
)
:
RequestHandler
(
sync_mode
)
{
this
->
checkpoint_notify_id
=
checkpoint_notify_id
;
}
virtual
~
RequestCheckpointHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
=
""
)
override
;
private:
int
checkpoint_notify_id
;
};
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
b20fa022
...
@@ -56,6 +56,10 @@ class RPCClient {
...
@@ -56,6 +56,10 @@ class RPCClient {
virtual
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
virtual
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
void
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_grpc_deadline
)
=
0
;
// SendComplete tells all the server that current trainer have no more data
// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
// to train with other trainers.
...
...
paddle/fluid/operators/distributed/send_recv.proto
浏览文件 @
b20fa022
...
@@ -25,6 +25,8 @@ service SendRecvService {
...
@@ -25,6 +25,8 @@ service SendRecvService {
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
// pre-fetch variable by given variable name and Ids
// pre-fetch variable by given variable name and Ids
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
}
}
// VariableMessage is serialized paddle variable message.
// VariableMessage is serialized paddle variable message.
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
b20fa022
...
@@ -99,7 +99,8 @@ static int64_t GetTimestamp() {
...
@@ -99,7 +99,8 @@ static int64_t GetTimestamp() {
void
ListenAndServOp
::
RunSyncLoop
(
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
const
std
::
vector
<
int
>
&
prefetch_block_id_list
)
const
{
const
std
::
vector
<
int
>
&
prefetch_block_id_list
,
const
int
checkpoint_point_block_id
)
const
{
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
auto
optimize_blocks
=
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
@@ -208,7 +209,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -208,7 +209,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
while
(
true
)
{
while
(
true
)
{
if
(
rpc_service_
->
IsExit
())
{
if
(
rpc_service_
->
IsExit
())
{
LOG
(
INFO
)
<<
"get exit!rpc_processor break!"
;
VLOG
(
4
)
<<
"get exit!rpc_processor break!"
;
break
;
break
;
}
}
...
@@ -223,6 +224,7 @@ static void FillRequestCtx(
...
@@ -223,6 +224,7 @@ static void FillRequestCtx(
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
*
prefetch_ctx
,
*
prefetch_ctx
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
checkpoint_ctx
,
distributed
::
RPCServer
*
rpc_server
)
{
distributed
::
RPCServer
*
rpc_server
)
{
h
->
SetScope
(
scope
);
h
->
SetScope
(
scope
);
h
->
SetDevCtx
(
dev_ctx
);
h
->
SetDevCtx
(
dev_ctx
);
...
@@ -230,6 +232,7 @@ static void FillRequestCtx(
...
@@ -230,6 +232,7 @@ static void FillRequestCtx(
h
->
SetProgram
(
program
);
h
->
SetProgram
(
program
);
h
->
SetPrefetchPreparedCtx
(
prefetch_ctx
);
h
->
SetPrefetchPreparedCtx
(
prefetch_ctx
);
h
->
SetRPCServer
(
rpc_server
);
h
->
SetRPCServer
(
rpc_server
);
h
->
SetCheckpointNotifyPreparedCtx
(
checkpoint_ctx
);
}
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
...
@@ -245,9 +248,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -245,9 +248,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
VLOG
(
4
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
;
<<
", end_point:"
<<
endpoint
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
;
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
...
@@ -255,6 +260,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -255,6 +260,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync_mode
));
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync_mode
));
request_prefetch_handler_
.
reset
(
request_prefetch_handler_
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
sync_mode
));
new
distributed
::
RequestPrefetchHandler
(
sync_mode
));
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
sync_mode
,
checkpoint_block_id
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
());
request_send_handler_
.
get
());
...
@@ -262,6 +269,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -262,6 +269,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_
.
get
());
request_get_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestPrefetch
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestPrefetch
,
request_prefetch_handler_
.
get
());
request_prefetch_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestCheckpoint
,
request_checkpoint_handler_
.
get
());
auto
optimize_blocks
=
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
@@ -270,6 +279,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -270,6 +279,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto
*
program
=
optimize_blocks
[
0
]
->
Program
();
auto
*
program
=
optimize_blocks
[
0
]
->
Program
();
framework
::
Executor
executor
(
dev_place
);
framework
::
Executor
executor
(
dev_place
);
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
nullptr
;
if
(
checkpoint_block_id
!=
-
1
)
{
auto
ctx
=
executor
.
Prepare
(
*
program
,
checkpoint_block_id
);
// see: https://stackoverflow.com/a/14856553
ckpt_pre_context
=
std
::
move
(
ctx
);
}
// prepare for prefetch
// prepare for prefetch
std
::
vector
<
int
>
prefetch_block_id_list
;
std
::
vector
<
int
>
prefetch_block_id_list
;
std
::
unordered_map
<
int
,
std
::
string
>
block_id_to_prefetch_var_name
;
std
::
unordered_map
<
int
,
std
::
string
>
block_id_to_prefetch_var_name
;
...
@@ -300,13 +316,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -300,13 +316,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx
[
prefetch_var_name
]
=
prefetch_prepared
[
i
];
prefetch_var_name_to_prepared_ctx
[
prefetch_var_name
]
=
prefetch_prepared
[
i
];
}
}
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
auto
f
=
&
dev_ctx
,
&
executor
,
program
,
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
prefetch_var_name_to_prepared_ctx
,
rpc_service_
.
get
());
&
executor
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
ckpt_pre_context
,
rpc_service_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
f
(
request_checkpoint_handler_
.
get
());
// start the server listening after all member initialized.
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
@@ -320,7 +338,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -320,7 +338,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
// Write to a file of server selected port for python use.
SavePort
();
SavePort
();
if
(
sync_mode
)
{
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block_id_list
);
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block_id_list
,
checkpoint_block_id
);
}
else
{
}
else
{
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
}
}
...
@@ -352,6 +371,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -352,6 +371,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
AddAttr
<
int
>
(
kCheckpointBlockId
,
"BolckID to run save checkpoint on pserer."
)
.
SetDefault
(
-
1
);
}
}
};
};
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
b20fa022
...
@@ -32,6 +32,7 @@ namespace operators {
...
@@ -32,6 +32,7 @@ namespace operators {
constexpr
char
kOptimizeBlocks
[]
=
"optimize_blocks"
;
constexpr
char
kOptimizeBlocks
[]
=
"optimize_blocks"
;
constexpr
char
kPrefetchVarNameToBlockId
[]
=
"prefetch_var_name_to_block_id"
;
constexpr
char
kPrefetchVarNameToBlockId
[]
=
"prefetch_var_name_to_block_id"
;
constexpr
char
kCheckpointBlockId
[]
=
"checkpint_block_id"
;
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
);
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
);
...
@@ -47,7 +48,8 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -47,7 +48,8 @@ class ListenAndServOp : public framework::OperatorBase {
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
const
std
::
vector
<
int
>&
prefetch_block_id_list
)
const
;
const
std
::
vector
<
int
>&
prefetch_block_id_list
,
const
int
checkpoint_point_block_id
)
const
;
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
...
@@ -68,6 +70,8 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -68,6 +70,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_prefetch_handler_
;
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_checkpoint_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
};
};
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
b20fa022
...
@@ -34,6 +34,8 @@ class LoadOp : public framework::OperatorBase {
...
@@ -34,6 +34,8 @@ class LoadOp : public framework::OperatorBase {
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
platform
::
RecordEvent
record_event
(
Type
(),
dev_ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
dev_ctx
);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
std
::
ifstream
fin
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
...
@@ -44,9 +46,25 @@ class LoadOp : public framework::OperatorBase {
...
@@ -44,9 +46,25 @@ class LoadOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found"
,
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output variable %s cannot be found"
,
out_var_name
);
out_var_name
);
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
LoadLodTensor
(
fin
,
place
,
out_var
);
}
else
if
(
out_var
->
IsType
<
framework
::
SelectedRows
>
())
{
LoadSelectedRows
(
fin
,
place
,
out_var
);
}
else
{
PADDLE_ENFORCE
(
false
,
"Load only support LoDTensor and SelectedRows, %s has wrong type"
,
out_var_name
);
}
}
DeserializeFromStream
(
fin
,
tensor
,
*
dev_ctx
);
void
LoadLodTensor
(
std
::
istream
&
fin
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
DeserializeFromStream
(
fin
,
tensor
,
dev_ctx
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
->
type
());
auto
in_dtype
=
framework
::
ToDataType
(
tensor
->
type
());
...
@@ -63,18 +81,27 @@ class LoadOp : public framework::OperatorBase {
...
@@ -63,18 +81,27 @@ class LoadOp : public framework::OperatorBase {
&
fp16_tensor
);
&
fp16_tensor
);
// reset output tensor
// reset output tensor
out_
var
->
Clear
();
var
->
Clear
();
tensor
=
out_
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
set_lod
(
fp16_tensor
.
lod
());
tensor
->
set_lod
(
fp16_tensor
.
lod
());
tensor
->
ShareDataWith
(
fp16_tensor
);
tensor
->
ShareDataWith
(
fp16_tensor
);
}
}
}
}
void
LoadSelectedRows
(
std
::
istream
&
fin
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
auto
*
selectedRows
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
DeserializeFromStream
(
fin
,
selectedRows
,
dev_ctx
);
}
};
};
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddOutput
(
"Out"
,
"The
tensor
need to be loaded"
);
AddOutput
(
"Out"
,
"The
LoDTensor / SelectedRows
need to be loaded"
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"load_as_fp16"
,
"load_as_fp16"
,
"If true, the tensor will be first loaded and then "
"If true, the tensor will be first loaded and then "
...
@@ -85,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -85,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")"
)
R"(Variable will be loaded from "file_path")"
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddComment
(
"Load operator will load a tensor variable from disk file."
);
AddComment
(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file."
);
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/save_load_op_test.cc
浏览文件 @
b20fa022
...
@@ -139,6 +139,7 @@ TEST(LoadFP16Op, CPU) {
...
@@ -139,6 +139,7 @@ TEST(LoadFP16Op, CPU) {
save_op
->
Run
(
scope
,
place
);
save_op
->
Run
(
scope
,
place
);
auto
load_var
=
scope
.
Var
(
"out_var"
);
auto
load_var
=
scope
.
Var
(
"out_var"
);
load_var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
auto
load_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
auto
load_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"load"
,
{},
{{
"Out"
,
{
"out_var"
}}},
attrs
);
"load"
,
{},
{{
"Out"
,
{
"out_var"
}}},
attrs
);
load_op
->
Run
(
scope
,
place
);
load_op
->
Run
(
scope
,
place
);
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
b20fa022
...
@@ -22,11 +22,17 @@ limitations under the License. */
...
@@ -22,11 +22,17 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr
char
LOOKUP_TABLE_PATH
[]
=
"kLookupTablePath"
;
// TODO(yuyang18): If the functions below are needed by other files, move them
// TODO(yuyang18): If the functions below are needed by other files, move them
// to paddle::filesystem namespace.
// to paddle::filesystem namespace.
constexpr
char
kSEP
=
'/'
;
constexpr
char
kSEP
=
'/'
;
...
@@ -67,9 +73,27 @@ class SaveOp : public framework::OperatorBase {
...
@@ -67,9 +73,27 @@ class SaveOp : public framework::OperatorBase {
private:
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
iname
=
Input
(
"X"
);
auto
*
var
=
scope
.
FindVar
(
iname
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s for save_op"
,
iname
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
SaveLodTensor
(
place
,
var
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
SaveSelectedRows
(
scope
,
place
,
var
);
}
else
{
PADDLE_ENFORCE
(
false
,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type"
,
iname
);
}
}
void
SaveLodTensor
(
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
if
(
FileExists
(
filename
)
&&
!
overwrite
)
{
if
(
FileExists
(
filename
)
&&
!
overwrite
)
{
PADDLE_THROW
(
"%s is existed, cannot save to it when overwrite=false"
,
PADDLE_THROW
(
"%s is existed, cannot save to it when overwrite=false"
,
...
@@ -78,26 +102,19 @@ class SaveOp : public framework::OperatorBase {
...
@@ -78,26 +102,19 @@ class SaveOp : public framework::OperatorBase {
MkDirRecursively
(
DirName
(
filename
).
c_str
());
MkDirRecursively
(
DirName
(
filename
).
c_str
());
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
auto
iname
=
Input
(
"X"
);
auto
*
var
=
scope
.
FindVar
(
iname
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s for save_op"
,
iname
);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
"SaveOp only support LoDTensor, %s has wrong type"
,
iname
);
auto
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
.
type
());
auto
in_dtype
=
framework
::
ToDataType
(
tensor
.
type
());
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
@@ -112,17 +129,43 @@ class SaveOp : public framework::OperatorBase {
...
@@ -112,17 +129,43 @@ class SaveOp : public framework::OperatorBase {
}
else
{
}
else
{
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
}
}
fout
.
close
();
}
void
SaveSelectedRows
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
auto
*
lt_var
=
scope
.
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
PADDLE_ENFORCE
(
lt_var
!=
nullptr
,
"Can not find variable kLookupTablePath for SaveSelectedRows"
);
std
::
string
filename
=
lt_var
->
data
();
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
auto
&
selectedRows
=
var
->
Get
<
framework
::
SelectedRows
>
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std
::
ofstream
fout
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fout
),
"Cannot open %s to write"
,
filename
);
framework
::
SerializeToStream
(
fout
,
selectedRows
,
dev_ctx
);
fout
.
close
();
}
}
};
};
class
SaveOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
SaveOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor ) Input
tensor
to be saved"
);
AddInput
(
"X"
,
"(Tensor ) Input
LoDTensor and SelectedRows
to be saved"
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Save operator
Save operator
This operator will serialize and write
a tensor
variable to file on disk.
This operator will serialize and write
LoDTensor / SelectedRows
variable to file on disk.
)DOC"
);
)DOC"
);
AddAttr
<
bool
>
(
"overwrite"
,
AddAttr
<
bool
>
(
"overwrite"
,
"(boolean, default true)"
"(boolean, default true)"
...
@@ -142,9 +185,26 @@ This operator will serialize and write a tensor variable to file on disk.
...
@@ -142,9 +185,26 @@ This operator will serialize and write a tensor variable to file on disk.
}
}
};
};
class
SaveOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
LOOKUP_TABLE_PATH
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
SaveOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
save
,
ops
::
SaveOp
,
ops
::
SaveOpProtoMaker
);
REGISTER_OPERATOR
(
save
,
ops
::
SaveOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SaveOpProtoMaker
,
ops
::
SaveOpVarTypeInference
,
ops
::
SaveOpShapeInference
);
python/paddle/fluid/framework.py
浏览文件 @
b20fa022
...
@@ -454,7 +454,7 @@ class Operator(object):
...
@@ -454,7 +454,7 @@ class Operator(object):
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'listen_and_serv'
,
'parallel_do'
,
'save_combine'
,
'load_combine'
,
'listen_and_serv'
,
'parallel_do'
,
'save_combine'
,
'load_combine'
,
'ncclInit'
,
'channel_create'
,
'channel_close'
,
'channel_send'
,
'ncclInit'
,
'channel_create'
,
'channel_close'
,
'channel_send'
,
'channel_recv'
,
'select'
,
'gen_nccl_id'
'channel_recv'
,
'select'
,
'
checkpoint_notify'
,
'
gen_nccl_id'
}
}
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -1214,6 +1214,9 @@ class Block(object):
...
@@ -1214,6 +1214,9 @@ class Block(object):
if
var
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
:
if
var
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
:
ret_var
=
self
.
create_var
(
ret_var
=
self
.
create_var
(
name
=
var
.
name
,
persistable
=
var
.
persistable
,
type
=
var
.
type
)
name
=
var
.
name
,
persistable
=
var
.
persistable
,
type
=
var
.
type
)
elif
var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
ret_var
=
self
.
create_var
(
name
=
var
.
name
,
persistable
=
var
.
persistable
,
type
=
var
.
type
)
elif
var
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
elif
var
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
ret_var
=
self
.
create_var
(
ret_var
=
self
.
create_var
(
name
=
var
.
name
,
name
=
var
.
name
,
...
@@ -1923,7 +1926,7 @@ def get_var(name, program=None):
...
@@ -1923,7 +1926,7 @@ def get_var(name, program=None):
Args:
Args:
name(str): name of the variable
name(str): name of the variable
program(Program|None): program object.
program(Program|None): program object.
If None, default_global_program() will be used.
If None, default_global_program() will be used.
Returns:
Returns:
Variable
Variable
...
...
python/paddle/fluid/io.py
浏览文件 @
b20fa022
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
errno
import
time
import
time
import
shutil
import
shutil
...
@@ -25,7 +26,8 @@ __all__ = [
...
@@ -25,7 +26,8 @@ __all__ = [
'load_persistables'
,
'save_inference_model'
,
'load_inference_model'
,
'load_persistables'
,
'save_inference_model'
,
'load_inference_model'
,
'get_inference_program'
,
'save_checkpoint'
,
'load_checkpoint'
,
'get_inference_program'
,
'save_checkpoint'
,
'load_checkpoint'
,
'clean_checkpoint'
,
'load_persist_vars_without_grad'
,
'clean_checkpoint'
,
'load_persist_vars_without_grad'
,
'save_persist_vars_without_grad'
,
'get_latest_checkpoint_serial'
'load_lookup_table_vars'
,
'save_persist_vars_without_grad'
,
'get_latest_checkpoint_serial'
]
]
...
@@ -795,6 +797,7 @@ def get_parameter_value_by_name(name, executor, program=None):
...
@@ -795,6 +797,7 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
CHECKPOINT_PREFIX
=
"checkpoint"
MODEL_DIR
=
"__model__"
MODEL_DIR
=
"__model__"
LOOKUP_TABLE_DIR
=
"__lookup_table__"
TRAINER_PREFIX
=
"trainer"
TRAINER_PREFIX
=
"trainer"
CHECKPOINT_SEPARATOR
=
"_"
CHECKPOINT_SEPARATOR
=
"_"
...
@@ -804,7 +807,9 @@ def save_checkpoint(executor,
...
@@ -804,7 +807,9 @@ def save_checkpoint(executor,
trainer_id
,
trainer_id
,
trainer_args
=
None
,
trainer_args
=
None
,
main_program
=
None
,
main_program
=
None
,
max_num_checkpoints
=
3
):
max_num_checkpoints
=
3
,
lookup_table
=
None
,
ps_endpoint_list
=
None
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
main_program and then saves these variables to the `checkpoint_dir`
...
@@ -836,6 +841,12 @@ def save_checkpoint(executor,
...
@@ -836,6 +841,12 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The max number of total number of existing
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
checkpoints.
Default: 3
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Returns:
Returns:
None
None
...
@@ -852,30 +863,40 @@ def save_checkpoint(executor,
...
@@ -852,30 +863,40 @@ def save_checkpoint(executor,
prog = fluid.default_main_program()
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_checkpoint(executor=exe,
fluid.io.save_checkpoint(executor=exe,
checkpoint_dir=path,
checkpoint_dir=path,
trainer_id=0,
trainer_id=0,
trainer_args=trainer_args,
trainer_args=trainer_args,
main_program=prog,
main_program=prog,
max_num_checkpoints=3)
max_num_checkpoints=3,
lookup_table=table_name,
ps_endpoint_list = ps_endpoints)
"""
"""
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
assert
checkpoint_dir
if
trainer_args
:
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
assert
isinstance
(
trainer_args
,
dict
)
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
is_chief
=
trainer_id
==
0
os
.
makedirs
(
checkpoint_dir
)
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
serial
=
get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
trainer_id
==
0
:
if
is_chief
:
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
ps_endpoint_list
:
save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
ps_endpoint_list
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
...
@@ -942,8 +963,9 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
...
@@ -942,8 +963,9 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
"""
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
clean the checkpoint dir, when the train exits normally,
delete_dir only works when the directory is empty, otherwise, OSError is raised.
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param checkpoint_dir
: param delete_dir
: param delete_dir
...
@@ -1009,6 +1031,56 @@ def load_persist_vars_without_grad(executor,
...
@@ -1009,6 +1031,56 @@ def load_persist_vars_without_grad(executor,
filename
=
None
)
filename
=
None
)
def
load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/__model__"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for
var
in
program
.
list_vars
():
if
var
.
name
==
table_name
:
lookup_table_var
=
var
break
assert
lookup_table_var
is
not
None
lookup_table_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
table_file
=
table_name
+
CHECKPOINT_SEPARATOR
+
str
(
pserver_id
)
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
lookup_table_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
lookup_table_dir
,
table_file
)})
executor
.
run
(
load_prog
)
def
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
def
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
...
@@ -1055,6 +1127,54 @@ def save_persist_vars_without_grad(executor, dirname, program):
...
@@ -1055,6 +1127,54 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success
(
cur_dir
)
_write_success
(
cur_dir
)
def
save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir
=
_get_lookuptable_dir
(
dirname
)
checkpoint_notify_program
=
Program
()
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
ps_endpoint_list
attrs
[
'dir'
]
=
cur_dir
attrs
[
'lookup_table'
]
=
lookup_table
checkpoint_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
checkpoint_notify_program
)
def
save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
def
save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
assert
isinstance
(
trainer_args
,
dict
)
assert
isinstance
(
trainer_args
,
dict
)
...
@@ -1068,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
...
@@ -1068,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert
isinstance
(
trainer_args
,
list
)
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
...
@@ -1088,7 +1231,7 @@ def _is_checkpoint_var(var):
...
@@ -1088,7 +1231,7 @@ def _is_checkpoint_var(var):
the checkpoint will not save or load all the variables.
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var
: param var
(Variable)
"""
"""
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
...
@@ -1108,6 +1251,23 @@ def _is_checkpoint_var(var):
...
@@ -1108,6 +1251,23 @@ def _is_checkpoint_var(var):
return
var
.
persistable
return
var
.
persistable
def
_make_chekcpoint_dirs
(
dirs
):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert
dirs
is
not
None
if
os
.
path
.
isfile
(
dirs
):
raise
OSError
(
errno
.
ENOTDIR
,
"dirs path shoule be a Directory."
,
dirs
)
if
not
os
.
path
.
isdir
(
dirs
):
try
:
os
.
makedirs
(
dirs
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
EEXIST
:
raise
err
def
_get_dir_serial
(
dirname
):
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
...
@@ -1121,29 +1281,27 @@ def _get_dir_serial(dirname):
...
@@ -1121,29 +1281,27 @@ def _get_dir_serial(dirname):
def
_get_serial_dir
(
dirname
,
serial
):
def
_get_serial_dir
(
dirname
,
serial
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
if
not
os
.
path
.
isdir
(
serial_dir
):
os
.
makedirs
(
serial_dir
)
return
serial_dir
return
serial_dir
def
_get_model_dir
(
dirname
):
def
_get_model_dir
(
dirname
):
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
_make_chekcpoint_dirs
(
model_dir
)
return
model_dir
if
not
os
.
path
.
isdir
(
model_dir
):
os
.
makedirs
(
model_dir
)
return
model_dir
def
_get_lookuptable_dir
(
dirname
):
lookuptable_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
_make_chekcpoint_dirs
(
lookuptable_dir
)
return
lookuptable_dir
def
_get_trainer_dir
(
dirname
,
trainer_id
):
def
_get_trainer_dir
(
dirname
,
trainer_id
):
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
_make_chekcpoint_dirs
(
trainer_dir
)
if
not
os
.
path
.
isdir
(
trainer_dir
):
os
.
makedirs
(
trainer_dir
)
return
trainer_dir
return
trainer_dir
...
@@ -1162,7 +1320,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3):
...
@@ -1162,7 +1320,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3):
serials
=
serials
[
max_num_checkpoints
:]
serials
=
serials
[
max_num_checkpoints
:]
for
serial
in
serials
:
for
serial
in
serials
:
cur_dir
=
_get_serial_dir
(
dirname
,
serial
)
cur_dir
=
_get_serial_dir
(
dirname
,
serial
)
shutil
.
rmtree
(
cur_dir
)
try
:
shutil
.
rmtree
(
cur_dir
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
ENOENT
:
raise
err
def
_write_success
(
dirname
):
def
_write_success
(
dirname
):
...
...
python/paddle/fluid/trainer.py
浏览文件 @
b20fa022
...
@@ -119,27 +119,20 @@ class CheckpointConfig(object):
...
@@ -119,27 +119,20 @@ class CheckpointConfig(object):
max_num_checkpoints
=
3
,
max_num_checkpoints
=
3
,
epoch_interval
=
1
,
epoch_interval
=
1
,
step_interval
=
10
):
step_interval
=
10
):
if
checkpoint_dir
is
None
:
self
.
checkpoint_dir
=
os
.
getcwd
()
else
:
self
.
checkpoint_dir
=
checkpoint_dir
self
.
max_num_checkpoints
=
max_num_checkpoints
if
epoch_interval
<
1
:
self
.
epoch_interval
=
1
else
:
self
.
epoch_interval
=
epoch_interval
if
step_interval
<
1
:
assert
epoch_interval
>=
1
self
.
step_interval
=
10
assert
step_interval
>=
1
else
:
self
.
step_interval
=
step_interval
self
.
checkpoint_dir
=
checkpoint_dir
\
if
checkpoint_dir
is
not
None
else
os
.
getcwd
()
self
.
max_num_checkpoints
=
max_num_checkpoints
self
.
epoch_interval
=
epoch_interval
self
.
step_interval
=
step_interval
self
.
epoch_id
=
0
self
.
epoch_id
=
0
self
.
step_id
=
0
self
.
step_id
=
0
self
.
load_serial
=
None
self
.
load_serial
=
None
self
.
is_pserver
=
False
self
.
pserver_id
=
None
self
.
lookup_table_name
=
None
def
check_and_get_place
(
place
):
def
check_and_get_place
(
place
):
...
@@ -290,13 +283,20 @@ class Trainer(object):
...
@@ -290,13 +283,20 @@ class Trainer(object):
self
.
checkpoint_cfg
.
load_serial
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
startup_program
)
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
is_pserver
:
if
not
self
.
checkpoint_cfg
.
pserver_id
:
epoch_id
,
step_id
=
io
.
load_trainer_args
(
epoch_id
,
step_id
=
io
.
load_trainer_args
(
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
trainer_id
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
trainer_id
,
self
.
_get_checkpoint_load_args
())
self
.
_get_checkpoint_load_args
())
self
.
checkpoint_cfg
.
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint_cfg
.
step_id
=
int
(
step_id
)
self
.
checkpoint_cfg
.
step_id
=
int
(
step_id
)
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
io
.
load_lookup_table_vars
(
exe
,
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
startup_program
,
self
.
checkpoint_cfg
.
pserver_id
,
self
.
checkpoint_cfg
.
lookup_table_name
)
if
param_path
and
os
.
path
.
isdir
(
param_path
):
if
param_path
and
os
.
path
.
isdir
(
param_path
):
# load params from param_path into scope
# load params from param_path into scope
...
@@ -366,7 +366,10 @@ class Trainer(object):
...
@@ -366,7 +366,10 @@ class Trainer(object):
self
.
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
self
.
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
training_role
==
"PSERVER"
:
if
self
.
checkpoint_cfg
:
if
self
.
checkpoint_cfg
:
self
.
is_pserver
=
True
pserver_id
=
eplist
.
index
(
current_endpoint
)
self
.
checkpoint_cfg
.
pserver_id
=
pserver_id
if
t
.
has_distributed_lookup_table
:
self
.
checkpoint_cfg
.
lookup_table_name
=
t
.
table_name
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
...
@@ -566,7 +569,8 @@ class Trainer(object):
...
@@ -566,7 +569,8 @@ class Trainer(object):
def
_save_checkpoint
(
self
,
epoch_id
,
step_id
):
def
_save_checkpoint
(
self
,
epoch_id
,
step_id
):
assert
self
.
checkpoint_cfg
assert
self
.
checkpoint_cfg
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
\
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
exe
=
executor
.
Executor
(
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
io
.
save_checkpoint
(
io
.
save_checkpoint
(
executor
=
exe
,
executor
=
exe
,
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b20fa022
...
@@ -471,6 +471,8 @@ class DistributeTranspiler(object):
...
@@ -471,6 +471,8 @@ class DistributeTranspiler(object):
pserver_index
,
pserver_program
,
pre_block_idx
,
grad_to_block_id
)
pserver_index
,
pserver_program
,
pre_block_idx
,
grad_to_block_id
)
prefetch_var_name_to_block_id
=
self
.
_create_prefetch_block
(
prefetch_var_name_to_block_id
=
self
.
_create_prefetch_block
(
pserver_index
,
pserver_program
,
table_opt_block
)
pserver_index
,
pserver_program
,
table_opt_block
)
checkpoint_block_id
=
self
.
_create_checkpoint_save_block
(
pserver_program
,
table_opt_block
.
idx
)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
# not be executed, so it's safe to use optimize_block to hold the place
...
@@ -489,6 +491,7 @@ class DistributeTranspiler(object):
...
@@ -489,6 +491,7 @@ class DistributeTranspiler(object):
if
len
(
prefetch_var_name_to_block_id
)
>
0
:
if
len
(
prefetch_var_name_to_block_id
)
>
0
:
attrs
[
'prefetch_var_name_to_block_id'
]
\
attrs
[
'prefetch_var_name_to_block_id'
]
\
=
prefetch_var_name_to_block_id
=
prefetch_var_name_to_block_id
attrs
[
'checkpint_block_id'
]
=
checkpoint_block_id
# step5 append the listen_and_serv op
# step5 append the listen_and_serv op
pserver_program
.
global_block
().
append_op
(
pserver_program
.
global_block
().
append_op
(
...
@@ -910,6 +913,27 @@ class DistributeTranspiler(object):
...
@@ -910,6 +913,27 @@ class DistributeTranspiler(object):
return
table_opt_block
return
table_opt_block
def
_create_checkpoint_save_block
(
self
,
pserver_program
,
pre_block_idx
):
"""
create a new block to handle save checkpoint.
"""
import
os
pserver_program
.
global_block
().
create_var
(
name
=
"kLookupTablePath"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
checkpoint_save_block
=
pserver_program
.
create_block
(
pre_block_idx
)
# this 'file_path' do not be used in save lookup table variable
checkpoint_save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
self
.
table_name
]},
outputs
=
{},
attrs
=
{
'file_path'
:
"none"
})
return
checkpoint_save_block
.
idx
def
_create_vars_from_blocklist
(
self
,
def
_create_vars_from_blocklist
(
self
,
program
,
program
,
block_list
,
block_list
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录