Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9ba2ae10
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9ba2ae10
编写于
8月 21, 2020
作者:
S
seiriosPlus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize init from pserver
上级
4b4e558a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
42 addition
and
56 deletion
+42
-56
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+37
-53
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+5
-1
python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py
...ddle/distributed/fleet/meta_optimizers/async_optimizer.py
+0
-1
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
.../fleet/parameter_server/distribute_transpiler/__init__.py
+0
-1
未找到文件。
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
9ba2ae10
...
...
@@ -74,8 +74,12 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
}
else
{
recv_threadpool_
.
reset
(
new
::
ThreadPool
(
thread_pool_size_
));
}
InitParams
();
}
void
AsyncCommunicator
::
InitParams
()
{
RecvNoBarrier
();
}
AsyncCommunicator
::~
AsyncCommunicator
()
{
running_
=
false
;
if
(
main_thread_
)
main_thread_
->
join
();
...
...
@@ -721,7 +725,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
t_timestamp
->
data
<
float
>
());
}
void
GeoCommunicator
::
Init
()
{
void
GeoCommunicator
::
Init
Params
()
{
std
::
vector
<
std
::
future
<
void
>>
tasks
;
tasks
.
reserve
(
recv_varname_to_ctx_
.
size
());
...
...
@@ -744,12 +748,17 @@ void GeoCommunicator::Init() {
}
void
GeoCommunicator
::
InitDense
(
const
std
::
string
varname
)
{
auto
*
var
=
old_scope_
->
Var
(
varname
);
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
&
ctx
=
recv_varname_to_ctx_
.
at
(
varname
);
auto
recv
=
distributed
::
ParameterRecv
<
float
>
();
recv
(
ctx
,
*
old_scope_
);
recv
(
ctx
,
*
recv_scope_
);
auto
*
global_var
=
recv_scope_
->
FindVar
(
varname
);
global_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
old_var
=
old_scope_
->
Var
(
varname
);
old_var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
CopyVariable
(
*
global_var
,
old_var
);
VLOG
(
1
)
<<
"init dense variable "
<<
varname
<<
" done"
;
}
...
...
@@ -781,68 +790,43 @@ void GeoCommunicator::InitSparse() {
LargeScaleKV
::
Init
(
metas
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id_
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
framework
::
Scope
&
local_scope
=
send_scope_
->
NewScope
();
for
(
auto
&
meta
:
metas
)
{
auto
&
ctx
=
recv_varname_to_ctx_
.
at
(
meta
.
name
);
auto
pserver_num
=
ctx
.
splited_varnames
.
size
();
for
(
size_t
i
=
0
;
i
<
ctx
.
splited_varnames
.
size
();
i
++
)
{
auto
&
recv_var_name
=
ctx
.
splited_varnames
[
i
];
auto
*
var
=
local_scope
.
Var
(
recv_var_name
);
var
->
GetMutable
<
framework
::
LoDTensor
>
();
distributed
::
VarHandlePtr
ret
;
ret
=
rpc_client
->
AsyncGetVarNoBarrier
(
ctx
.
epmap
[
i
],
cpu_ctx
,
local_scope
,
recv_var_name
,
recv_var_name
);
auto
*
recv_var
=
local_scope
.
FindVar
(
recv_var_name
);
auto
&
recv_t
=
recv_var
->
Get
<
framework
::
LoDTensor
>
();
auto
recv
=
distributed
::
ParameterRecv
<
float
>
();
auto
width
=
recv_t
.
dims
()[
1
];
auto
rows
=
recv_t
.
dims
()[
0
];
auto
*
global_var
=
recv_scope_
->
FindVar
(
meta
.
name
);
auto
global_value
=
global_var
->
Get
<
framework
::
LoDTensor
>
();
auto
rows
=
global_value
.
dims
()[
0
];
auto
dim1
=
global_value
.
dims
()[
1
];
PADDLE_ENFORCE_EQ
(
width
,
meta
.
value_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"sparse params do not match"
));
recv
(
ctx
,
*
recv_scope_
);
VLOG
(
1
)
<<
"recv "
<<
meta
.
name
<<
" with global scope for init"
;
std
::
vector
<
int64_t
>
ids
;
auto
n_rows
=
global_var
->
Get
<
framework
::
LoDTensor
>
().
dims
()[
0
]
;
for
(
int
x
=
0
;
x
<
rows
;
x
++
)
{
ids
.
push_back
(
x
*
pserver_num
+
i
);
}
PADDLE_ENFORCE_EQ
(
rows
,
n_rows
,
platform
::
errors
::
InvalidArgument
(
"global var: %s origin dim must equal recved rows"
,
meta
.
name
));
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>
*>>
values
;
auto
*
ins
=
distributed
::
LargeScaleKV
::
GetInstance
(
);
std
::
vector
<
int64_t
>
ids
(
rows
)
;
std
::
iota
(
ids
.
begin
(),
ids
.
end
(),
0
);
ins
->
Get
(
meta
.
name
)
->
Init
(
ids
);
ins
->
Get
(
meta
.
name
)
->
Get
(
ids
,
{
"Param"
},
&
values
)
;
auto
*
ins
=
distributed
::
LargeScaleKV
::
GetInstance
(
);
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>
*>>
values
;
PADDLE_ENFORCE_NE
(
ret
->
Wait
(),
0U
,
platform
::
errors
::
ExecutionTimeout
(
"internal error in RPCClient"
)
);
ins
->
Get
(
meta
.
name
)
->
Init
(
ids
);
ins
->
Get
(
meta
.
name
)
->
Get
(
ids
,
{
"Param"
},
&
values
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
float
>
(
paddle
::
platform
::
CPUDeviceContext
());
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
float
>
(
paddle
::
platform
::
CPUDeviceContext
());
for
(
size_t
k
=
0
;
k
<
ids
.
size
();
++
k
)
{
blas
.
VCOPY
(
width
,
recv_t
.
data
<
float
>
()
+
k
*
width
,
values
[
k
][
0
]
->
data
());
}
local_scope
.
EraseVars
({
recv_var_name
});
for
(
auto
&
id
:
ids
)
{
blas
.
VCOPY
(
dim1
,
global_value
.
data
<
float
>
()
+
k
*
width
,
values
[
id
][
0
]
->
data
());
}
}
send_scope_
->
DeleteScope
(
&
local_scope
);
VLOG
(
3
)
<<
"init sparse variable done"
;
}
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
9ba2ae10
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
...
...
@@ -29,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
...
...
@@ -279,6 +281,8 @@ class AsyncCommunicator : public Communicator {
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
override
;
void
InitParams
();
void
MainThread
();
void
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
...
...
@@ -435,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator {
void
RecvDense
(
const
std
::
string
&
varname
);
void
Init
();
void
Init
Params
();
void
InitSparse
();
...
...
python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py
浏览文件 @
9ba2ae10
...
...
@@ -65,7 +65,6 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
# for startup program
_startup
=
worker
.
fake_init_ops_pass
(
_startup
,
compiled_config
)
_startup
=
worker
.
init_from_server_pass
(
_startup
,
compiled_config
)
_startup
=
worker
.
delet_extra_optimizes_pass
(
_startup
,
compiled_config
)
else
:
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
浏览文件 @
9ba2ae10
...
...
@@ -771,7 +771,6 @@ class ParameterServerOptimizer(DistributedOptimizer):
# for startup program
_startup
=
worker
.
fake_init_ops_pass
(
_startup
,
compiled_config
)
_startup
=
worker
.
init_from_server_pass
(
_startup
,
compiled_config
)
_startup
=
worker
.
delet_extra_optimizes_pass
(
_startup
,
compiled_config
)
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录