Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f70096a2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f70096a2
编写于
9月 23, 2020
作者:
S
seiriosPlus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mode for save
上级
11d17938
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
34 addition
and
23 deletion
+34
-23
paddle/fluid/operators/distributed/brpc/brpc_client.cc
paddle/fluid/operators/distributed/brpc/brpc_client.cc
+1
-0
paddle/fluid/operators/distributed/brpc/brpc_client.h
paddle/fluid/operators/distributed/brpc/brpc_client.h
+1
-1
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+2
-0
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+1
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+5
-5
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+2
-1
paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc
...e/fluid/operators/distributed_ops/checkpoint_notify_op.cc
+12
-8
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
...dle/distributed/fleet/runtime/parameter_server_runtime.py
+10
-7
未找到文件。
paddle/fluid/operators/distributed/brpc/brpc_client.cc
浏览文件 @
f70096a2
...
...
@@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
VarHandlePtr
BRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dirname
,
const
std
::
string
&
varname
,
const
int
mode
,
int64_t
time_out
)
{
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
varname
);
...
...
paddle/fluid/operators/distributed/brpc/brpc_client.h
浏览文件 @
f70096a2
...
...
@@ -103,7 +103,7 @@ class BRPCClient : public RPCClient {
VarHandlePtr
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dirname
,
const
std
::
string
&
varname
,
const
std
::
string
&
varname
,
const
int
mode
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
bool
Wait
()
override
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
f70096a2
...
...
@@ -420,6 +420,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
VarHandlePtr
GRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dirname
,
const
std
::
string
&
varname
,
const
int
mode
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
...
...
@@ -433,6 +434,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
varname
);
req
.
set_table_name
(
std
::
to_string
(
mode
));
req
.
set_out_varname
(
dirname
);
platform
::
RecordRPCEvent
record_event
(
method
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
f70096a2
...
...
@@ -247,7 +247,7 @@ class GRPCClient : public RPCClient {
VarHandlePtr
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dirname
,
const
std
::
string
&
varname
,
const
std
::
string
&
varname
,
const
int
mode
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncDistributeNotify
(
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
f70096a2
...
...
@@ -266,13 +266,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"receive save var "
<<
varname
<<
" with path "
<<
out_var_name
;
int
mode
=
std
::
stoi
(
out_var_name
);
VLOG
(
4
)
<<
"receive save var "
<<
varname
<<
" with path "
<<
out_var_name
<<
" mode "
<<
mode
;
auto
*
ins
=
distributed
::
LargeScaleKV
::
GetInstance
();
ins
->
Get
(
varname
)
->
Save
(
out_var_name
);
// auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
// paddle::platform::CPUPlace cpu_place;
// checkpoint_op->Run(*scope_, cpu_place);
ins
->
Get
(
varname
)
->
Save
(
out_var_name
,
mode
);
return
true
;
}
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
f70096a2
...
...
@@ -78,7 +78,8 @@ class RPCClient {
virtual
VarHandlePtr
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dirname
,
const
std
::
string
&
varname
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
const
std
::
string
&
varname
,
const
int
mode
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
...
...
paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc
浏览文件 @
f70096a2
...
...
@@ -36,8 +36,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
std
::
string
dirname
=
Attr
<
std
::
string
>
(
"dirname"
);
std
::
string
varname
=
Attr
<
std
::
string
>
(
"varname"
);
auto
is_slice
=
Attr
<
bool
>
(
"is_slice"
);
VLOG
(
1
)
<<
"is_slice: "
<<
is_slice
;
auto
mode
=
Attr
<
int
>
(
"mode"
);
if
(
mode
!=
0
&&
mode
!=
1
&&
mode
!=
2
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"mode expected in [0/1/2], but got %d"
,
mode
));
}
std
::
vector
<
std
::
string
>
slice_varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"slice_varnames"
);
...
...
@@ -52,11 +56,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
auto
save_path
=
string
::
Sprintf
(
"%s/%s/%s"
,
dirname
,
varname
,
slice_varnames
[
i
]);
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
save_path
,
remote_varnames
[
i
]
);
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
save_path
,
remote_varnames
[
i
],
mode
);
VLOG
(
3
)
<<
"checkpoint notify sending with path: "
<<
save_path
<<
" and var:"
<<
slice_varnames
[
i
]
<<
" to "
<<
epmap
[
i
];
<<
" and var:"
<<
slice_varnames
[
i
]
<<
" to "
<<
epmap
[
i
]
<<
" with mode "
<<
mode
;
}
PADDLE_ENFORCE_EQ
(
rpc_client
->
Wait
(),
true
,
...
...
@@ -79,9 +84,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
"slice_varnames"
,
"(string vector) the slice vars need to be saved"
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"remote_varnames"
,
"(string vector) the slice vars need to be saved"
);
AddAttr
<
bool
>
(
"is_slice"
,
"is_slice=True means the var has been slice by parameter server"
);
AddAttr
<
int
>
(
"mode"
,
"mode=0/1/2 means nothing/save base/save delta"
)
.
SetDefault
(
0
);
AddComment
(
R"DOC(
CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
...
...
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
浏览文件 @
f70096a2
...
...
@@ -436,8 +436,7 @@ class ParameterServerRuntime(RuntimeBase):
executor
.
run
(
prog
)
return
context
.
keys
()
def
_save_distributed_params
(
self
,
executor
,
dirname
,
context
,
main_program
):
def
_save_distributed_params
(
self
,
executor
,
dirname
,
context
,
mode
):
prog
=
Program
()
block
=
prog
.
global_block
()
...
...
@@ -446,7 +445,7 @@ class ParameterServerRuntime(RuntimeBase):
type
=
'checkpoint_notify'
,
attrs
=
{
"varname"
:
name
,
"
is_slice"
:
Tru
e
,
"
mode"
:
mod
e
,
"slice_varnames"
:
var_ctx
.
split_varnames
(),
"remote_varnames"
:
var_ctx
.
split_varnames
(),
"endpoints"
:
var_ctx
.
split_endpoints
(),
...
...
@@ -456,7 +455,8 @@ class ParameterServerRuntime(RuntimeBase):
executor
.
run
(
prog
)
return
context
.
keys
()
def
_save_distributed_persistables
(
self
,
executor
,
dirname
,
main_program
):
def
_save_distributed_persistables
(
self
,
executor
,
dirname
,
main_program
,
mode
):
dense_ctx
=
self
.
compiled_strategy
.
get_communicator_recv_context
(
recv_type
=
1
)
...
...
@@ -473,7 +473,7 @@ class ParameterServerRuntime(RuntimeBase):
executor
,
dirname
,
sparse_ctx
,
main_program
)
recv_distributed_varnames
=
self
.
_save_distributed_params
(
executor
,
dirname
,
distributed_ctx
,
m
ain_program
)
executor
,
dirname
,
distributed_ctx
,
m
ode
)
saved_varnames
=
recv_dense_varnames
+
list
(
recv_sparse_varnames
)
+
list
(
recv_distributed_varnames
)
...
...
@@ -493,6 +493,7 @@ class ParameterServerRuntime(RuntimeBase):
executor
,
dirname
,
main_program
=
None
,
mode
=
0
,
**
kwargs
):
"""
This function filters out all variables with `persistable==True` from the
...
...
@@ -523,7 +524,8 @@ class ParameterServerRuntime(RuntimeBase):
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
)
self
.
_save_distributed_persistables
(
executor
,
dirname
,
main_program
)
self
.
_save_distributed_persistables
(
executor
,
dirname
,
main_program
,
mode
)
def
_ps_inference_save_inference_model
(
self
,
executor
,
...
...
@@ -569,7 +571,8 @@ class ParameterServerRuntime(RuntimeBase):
program
=
Program
.
parse_from_string
(
program_desc_str
)
program
.
_copy_dist_param_info_from
(
fluid
.
default_main_program
())
self
.
_ps_inference_save_persistables
(
executor
,
dirname
,
program
)
self
.
_ps_inference_save_persistables
(
executor
,
dirname
,
program
,
mode
=
0
)
def
_save_inference_model
(
self
,
*
args
,
**
kwargs
):
self
.
_ps_inference_save_inference_model
(
*
args
,
**
kwargs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录