Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
728ec1b4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
728ec1b4
编写于
9月 30, 2019
作者:
C
Chengmo
提交者:
GitHub
9月 30, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add GEO-SGD distribute training algorithm (#20018)
* refector geo sgd & communicator
上级
5365cd2f
变更
17
展开全部
显示空白变更内容
内联
并排
Showing
17 changed file
with
1518 addition
and
104 deletion
+1518
-104
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+566
-4
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+148
-10
paddle/fluid/operators/distributed/parameter_recv.cc
paddle/fluid/operators/distributed/parameter_recv.cc
+64
-8
paddle/fluid/operators/distributed/parameter_send.cc
paddle/fluid/operators/distributed/parameter_send.cc
+122
-55
paddle/fluid/operators/distributed/parameter_send.h
paddle/fluid/operators/distributed/parameter_send.h
+1
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+58
-5
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+6
-2
paddle/fluid/pybind/communicator_py.cc
paddle/fluid/pybind/communicator_py.cc
+15
-1
paddle/fluid/pybind/communicator_py.h
paddle/fluid/pybind/communicator_py.h
+4
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
python/paddle/fluid/communicator.py
python/paddle/fluid/communicator.py
+18
-2
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
.../fleet/parameter_server/distribute_transpiler/__init__.py
+20
-5
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
+42
-11
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py
+100
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+4
-0
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
+348
-0
未找到文件。
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
728ec1b4
此差异已折叠。
点击以展开。
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
728ec1b4
...
...
@@ -14,16 +14,17 @@ limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
...
...
@@ -170,6 +171,11 @@ class Communicator {
virtual
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
=
0
;
virtual
void
Send
(
const
std
::
vector
<
std
::
string
>&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>&
sparse_var_tables
,
const
framework
::
Scope
&
scope
)
=
0
;
virtual
void
Recv
()
=
0
;
virtual
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
...
...
@@ -179,6 +185,13 @@ class Communicator {
virtual
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
=
0
;
// for geo-sgd
virtual
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
=
0
;
static
Communicator
*
GetInstance
()
{
return
communicator_
.
get
();
}
static
std
::
shared_ptr
<
Communicator
>
GetInstantcePtr
()
{
...
...
@@ -194,6 +207,26 @@ class Communicator {
return
communicator_
.
get
();
}
template
<
typename
T
>
static
Communicator
*
InitInstance
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
{
std
::
call_once
(
init_flag_
,
&
Communicator
::
InitWithProgram
<
T
>
,
program
,
recv_scope
);
return
communicator_
.
get
();
}
template
<
typename
T
>
static
Communicator
*
InitInstance
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
{
std
::
call_once
(
init_flag_
,
&
Communicator
::
InitWithTranspilerInfo
<
T
>
,
program
,
training_scope
,
std
::
ref
(
vars_info
),
std
::
ref
(
trainers
),
std
::
ref
(
geo_need_push_nums
));
return
communicator_
.
get
();
}
// Init is called by InitInstance.
template
<
typename
T
>
static
void
InitWithRpcCtx
(
const
RpcCtxMap
&
send_varname_to_ctx
,
...
...
@@ -206,14 +239,6 @@ class Communicator {
}
}
template
<
typename
T
>
static
Communicator
*
InitInstance
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
{
std
::
call_once
(
init_flag_
,
&
Communicator
::
InitWithProgram
<
T
>
,
program
,
recv_scope
);
return
communicator_
.
get
();
}
template
<
typename
T
>
static
void
InitWithProgram
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
{
...
...
@@ -223,12 +248,28 @@ class Communicator {
}
}
template
<
typename
T
>
static
void
InitWithTranspilerInfo
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
{
if
(
communicator_
.
get
()
==
nullptr
)
{
communicator_
.
reset
(
new
T
());
communicator_
->
InitImpl
(
program
,
training_scope
,
std
::
ref
(
vars_info
),
std
::
ref
(
trainers
),
std
::
ref
(
geo_need_push_nums
));
}
}
protected:
bool
running_
=
false
;
static
std
::
shared_ptr
<
Communicator
>
communicator_
;
static
std
::
once_flag
init_flag_
;
};
using
SparseIdsMap
=
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
int64_t
>>
;
class
AsyncCommunicator
:
public
Communicator
{
public:
AsyncCommunicator
()
{}
...
...
@@ -251,6 +292,16 @@ class AsyncCommunicator : public Communicator {
void
SendThread
();
void
RecvThread
();
void
Send
(
const
std
::
vector
<
std
::
string
>&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>&
sparse_var_tables
,
const
framework
::
Scope
&
scope
)
override
;
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
override
;
private:
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>>
...
...
@@ -266,6 +317,93 @@ class AsyncCommunicator : public Communicator {
std
::
atomic_uint
grad_num_
{
0
};
// the num of gradient sent since last recv
};
class
GeoSgdCommunicator
:
public
Communicator
{
public:
GeoSgdCommunicator
()
{}
~
GeoSgdCommunicator
();
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
override
;
void
Start
()
override
;
void
Stop
()
override
;
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
override
;
void
Send
(
const
std
::
vector
<
std
::
string
>&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>&
sparse_var_tables
,
const
framework
::
Scope
&
scope
)
override
;
void
Recv
()
override
;
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
override
;
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
override
;
private:
void
SendThread
();
void
RecvAll
();
std
::
unordered_set
<
int64_t
>
SparseIdsMerge
(
const
std
::
vector
<
SparseIdsMap
>&
ids_send_vec
,
const
std
::
string
&
var_name
);
void
SendUpdateDenseVars
(
const
std
::
string
&
var_name
);
void
SendUpdateSparseVars
(
const
std
::
string
&
var_name
,
const
std
::
unordered_set
<
int64_t
>&
ids_table
);
void
RecvUpdateVars
(
const
std
::
string
&
var_name
);
void
GeoSgdDenseParamInit
(
framework
::
Scope
*
scope_x
,
framework
::
Scope
*
scope_y
,
const
std
::
string
var_name
);
void
GeoSgdSparseParamInit
(
framework
::
Scope
*
scope_x
,
framework
::
Scope
*
scope_y
,
const
std
::
string
var_name
);
const
std
::
string
VarToDeltaVar
(
const
std
::
string
var_name
)
{
std
::
string
delta_name
=
var_name
;
const
std
::
string
send_name
=
delta_name
.
append
(
".delta"
);
return
send_name
;
}
const
std
::
string
DeltaVarToVar
(
const
std
::
string
var_name
)
{
std
::
string
origin_name
=
var_name
;
origin_name
.
erase
(
origin_name
.
find
(
".delta"
),
6
);
const
std
::
string
param_name
=
origin_name
;
return
param_name
;
}
private:
int
trainer_nums_
=
1
;
int
geo_need_push_nums_
=
100
;
bool
is_geo_sgd_
=
false
;
Scope
*
training_scope_
;
std
::
shared_ptr
<
Scope
>
delta_scope_
;
// parameter local delta: recv - old
std
::
shared_ptr
<
Scope
>
old_scope_
;
// parameter local, storage the param after last recv
std
::
shared_ptr
<
Scope
>
pserver_scope_
;
// parameter on pserver,gloabl scope
RpcCtxMap
send_varname_to_ctx_
;
RpcCtxMap
recv_varname_to_ctx_
;
std
::
atomic_uint
have_push_
{
0
};
std
::
unordered_map
<
std
::
string
,
bool
>
var_list_
;
// if var is sparse, using selected rows, bool=true
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
SparseIdsMap
>>>
need_push_queue_
;
std
::
vector
<
SparseIdsMap
>
ids_send_vec_
;
std
::
unique_ptr
<::
ThreadPool
>
send_threadpool_
{
nullptr
};
std
::
unique_ptr
<::
ThreadPool
>
recv_threadpool_
{
nullptr
};
std
::
unique_ptr
<
std
::
thread
>
send_thread_
{
nullptr
};
};
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/parameter_recv.cc
浏览文件 @
728ec1b4
...
...
@@ -42,7 +42,7 @@ using DDim = framework::DDim;
template
<
typename
T
>
void
ParameterRecv
<
T
>::
operator
()(
const
RpcContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
VLOG
(
3
)
<<
"ParameterRecv in "
<<
rpc_ctx
.
var_name
;
VLOG
(
2
)
<<
"ParameterRecv in "
<<
rpc_ctx
.
var_name
;
std
::
unique_ptr
<
framework
::
Scope
>
local_scope
=
scope
.
NewTmpScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
@@ -54,15 +54,24 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto
*
recv_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
// recv all vars to local scope
if
(
recv_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
recv_var
->
IsType
<
framework
::
LoDTensor
>
()
||
recv_var
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
auto
&
recv_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
local_scope
->
Var
(
recv_var_name
);
VLOG
(
3
)
<<
"recv "
<<
recv_var_name
<<
" from "
<<
rpc_ctx
.
epmap
[
i
];
VLOG
(
4
)
<<
"recv "
<<
recv_var_name
<<
" from "
<<
rpc_ctx
.
epmap
[
i
];
if
(
recv_var
->
IsType
<
framework
::
LoDTensor
>
())
{
// sparse param in recv_scope is LoDTensor
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
rpc_ctx
.
epmap
[
i
],
cpu_ctx
,
*
local_scope
.
get
(),
recv_var_name
,
recv_var_name
));
*
local_scope
.
get
(),
recv_var_name
,
recv_var_name
));
}
else
{
// sparse param in pserver_scope is SelectedRows
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
rpc_ctx
.
epmap
[
i
],
cpu_ctx
,
*
local_scope
.
get
(),
recv_var_name
,
recv_var_name
,
recv_var_name
));
}
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
...
...
@@ -72,7 +81,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
}
// concat recved tensor into one var
{
if
(
recv_var
->
IsType
<
framework
::
LoDTensor
>
())
{
size_t
output_offset
=
0
;
size_t
row_offset
=
0
;
framework
::
Tensor
*
recv_tensor
=
...
...
@@ -126,9 +135,56 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
LOG
(
FATAL
)
<<
"recv_numel: "
<<
recv_numel
<<
" acture numel: "
<<
numel
;
}
PADDLE_ENFORCE_EQ
(
recv_numel
,
numel
);
}
else
if
(
recv_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
cpu_place
=
platform
::
CPUPlace
();
auto
*
slr
=
recv_var
->
GetMutable
<
framework
::
SelectedRows
>
();
slr
->
mutable_rows
()
->
clear
();
slr
->
mutable_value
()
->
mutable_data
<
float
>
({{}},
cpu_place
);
int64_t
width
=
0
;
int64_t
height
=
0
;
std
::
vector
<
int64_t
>
new_rows
{};
// trans sparse ids from local to global
std
::
vector
<
int64_t
>
abs_sections
=
ToAbsoluteSection
(
rpc_ctx
.
height_sections
);
for
(
int
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
auto
&
recv_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
auto
*
var
=
local_scope
->
FindVar
(
recv_var_name
);
auto
*
var_slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
var_slr_row
=
var_slr
->
mutable_rows
();
width
=
var_slr
->
mutable_value
()
->
dims
()[
1
];
height
+=
var_slr
->
height
();
auto
row_offset
=
abs_sections
[
i
];
VLOG
(
4
)
<<
"Recv split_var "
<<
recv_var_name
<<
" Row size "
<<
var_slr_row
->
size
();
for
(
size_t
j
=
0
;
j
<
var_slr_row
->
size
();
j
++
)
{
new_rows
.
push_back
(
row_offset
+
(
*
var_slr_row
)[
j
]);
}
}
slr
->
set_rows
(
new_rows
);
slr
->
set_height
(
height
);
slr
->
mutable_value
()
->
mutable_data
<
float
>
(
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
slr
->
mutable_rows
()
->
size
()),
width
}),
cpu_place
);
auto
*
slr_data
=
slr
->
mutable_value
()
->
data
<
float
>
();
size_t
row_offset
=
0
;
for
(
auto
&
recv_var_name
:
rpc_ctx
.
splited_var_names
)
{
auto
*
var
=
local_scope
->
FindVar
(
recv_var_name
);
auto
*
var_slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
var_slr_row
=
var_slr
->
mutable_rows
();
auto
var_slr_row_size
=
var_slr_row
->
size
();
auto
*
var_slr_data
=
var_slr
->
mutable_value
()
->
data
<
float
>
();
memcpy
(
slr_data
+
row_offset
*
width
,
var_slr_data
,
sizeof
(
float
)
*
width
*
var_slr_row_size
);
row_offset
+=
var_slr_row_size
;
}
}
VLOG
(
3
)
<<
"ParameterRecv out "
<<
rpc_ctx
.
var_name
;
VLOG
(
2
)
<<
"ParameterRecv out "
<<
rpc_ctx
.
var_name
;
}
template
struct
ParameterRecv
<
float
>;
...
...
paddle/fluid/operators/distributed/parameter_send.cc
浏览文件 @
728ec1b4
...
...
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
...
...
@@ -28,6 +28,7 @@
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -38,9 +39,44 @@ using LoDTensor = framework::LoDTensor;
using
SelectedRows
=
framework
::
SelectedRows
;
using
DDim
=
framework
::
DDim
;
typedef
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
EP_SPLIT_TABLE_PAIRS
;
inline
EP_SPLIT_TABLE_PAIRS
GetMultiFieldRpcContext
(
const
RpcContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
,
int
multi_parts
)
{
EP_SPLIT_TABLE_PAIRS
table_pairs
;
auto
*
send_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
if
(
send_var
->
IsType
<
framework
::
SelectedRows
>
())
{
PADDLE_ENFORCE_GT
(
multi_parts
,
0
,
"multi_parts must >=1"
);
if
(
multi_parts
==
1
)
{
for
(
int
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
table_pairs
.
push_back
(
std
::
make_pair
(
rpc_ctx
.
epmap
[
i
],
rpc_ctx
.
splited_var_names
[
i
]));
}
}
else
{
for
(
int
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
for
(
int
x
=
0
;
x
<
multi_parts
;
x
++
)
{
auto
table
=
string
::
Sprintf
(
"%s@%d@PIECE"
,
rpc_ctx
.
splited_var_names
[
i
],
x
);
table_pairs
.
push_back
(
std
::
make_pair
(
rpc_ctx
.
epmap
[
i
],
table
));
}
}
}
}
else
if
(
send_var
->
IsType
<
framework
::
LoDTensor
>
())
{
PADDLE_THROW
(
"GetMultiFieldRpcContext can not support LoDTensor current!"
);
}
else
{
PADDLE_THROW
(
"GetMultiFieldRpcContext unsupported var type!"
);
}
return
table_pairs
;
}
// namespace distributed
template
<
typename
T
>
void
ParameterSend
<
T
>::
operator
()(
const
RpcContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
,
bool
sync
)
{
const
framework
::
Scope
&
scope
,
bool
sync
,
int
multi_parts
)
{
std
::
unique_ptr
<
framework
::
Scope
>
local_scope
=
scope
.
NewTmpScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
@@ -49,9 +85,12 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
rpc_ctx
.
trainer_id
);
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
auto
*
send_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
size_t
out_num
=
rpc_ctx
.
splited_var_names
.
size
();
if
(
send_var
->
IsType
<
framework
::
LoDTensor
>
())
{
size_t
out_num
=
rpc_ctx
.
splited_var_names
.
size
();
if
(
out_num
>
1
)
{
auto
&
send_tensor
=
send_var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
send_tensor_dims
=
send_tensor
.
dims
();
...
...
@@ -77,6 +116,24 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
row_offset
+=
outs_dims
[
i
][
0
];
}
}
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
auto
&
send_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
VLOG
(
4
)
<<
"send var name: "
<<
send_var_name
;
auto
&
endpoint
=
rpc_ctx
.
epmap
[
i
];
VLOG
(
4
)
<<
"send var endpoint: "
<<
endpoint
;
VLOG
(
4
)
<<
"need send: "
<<
NeedSend
(
*
local_scope
.
get
(),
send_var_name
);
if
(
NeedSend
(
*
local_scope
.
get
(),
send_var_name
))
{
VLOG
(
3
)
<<
"sending "
<<
send_var_name
<<
" to "
<<
endpoint
;
rets
.
push_back
(
rpc_client
->
AsyncSendVar
(
endpoint
,
cpu_ctx
,
*
local_scope
.
get
(),
send_var_name
));
VLOG
(
4
)
<<
"send var "
<<
send_var_name
<<
" async handle done"
;
}
else
{
VLOG
(
3
)
<<
"don't send non-initialized variable: "
<<
rpc_ctx
.
splited_var_names
[
i
];
}
}
}
else
if
(
send_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
send_slr
=
send_var
->
Get
<
framework
::
SelectedRows
>
();
auto
abs_sections
=
ToAbsoluteSection
(
rpc_ctx
.
height_sections
);
...
...
@@ -85,84 +142,94 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
std
::
vector
<
std
::
vector
<
size_t
>>
outs_rows_idx
;
std
::
vector
<
std
::
vector
<
size_t
>>
outs_dense_idx
;
outs_rows_idx
.
resize
(
out_num
);
outs_dense_idx
.
resize
(
out_num
);
auto
table_pairs
=
GetMultiFieldRpcContext
(
rpc_ctx
,
scope
,
multi_parts
);
outs_rows_idx
.
resize
(
table_pairs
.
size
());
outs_dense_idx
.
resize
(
table_pairs
.
size
());
auto
row_numel
=
send_slr
.
value
().
numel
()
/
send_slr
.
value
().
dims
()[
0
];
auto
*
src
=
send_slr
.
value
().
data
<
T
>
();
// create output var in local scope
std
::
vector
<
framework
::
SelectedRows
*>
outs
;
for
(
auto
&
name
:
rpc_ctx
.
splited_var_names
)
{
auto
*
out
=
local_scope
->
Var
(
name
)
->
GetMutable
<
framework
::
SelectedRows
>
();
for
(
auto
&
table
:
table_pairs
)
{
auto
*
out
=
local_scope
->
Var
(
table
.
second
)
->
GetMutable
<
framework
::
SelectedRows
>
();
outs
.
push_back
(
out
);
}
// split rows index into output sparse vars
for
(
size_t
i
=
0
;
i
<
send_rows
.
size
();
++
i
)
{
size_t
out_idx
=
GetSectionIndex
(
send_rows
[
i
],
abs_sections
);
auto
ep_idx
=
GetSectionIndex
(
send_rows
[
i
],
abs_sections
);
auto
table_idx
=
send_rows
[
i
]
%
multi_parts
;
auto
out_idx
=
ep_idx
*
multi_parts
+
table_idx
;
outs_rows_idx
[
out_idx
].
push_back
(
send_rows
[
i
]);
outs_dense_idx
[
out_idx
].
push_back
(
i
);
}
auto
place
=
platform
::
CPUPlace
();
for
(
size_t
i
=
0
;
i
<
outs_rows_idx
.
size
();
++
i
)
{
auto
rows_idx
=
outs_rows_idx
[
i
];
outs
[
i
]
->
set_height
(
rpc_ctx
.
height_sections
[
i
]);
for
(
int
ctx
=
0
;
ctx
<
rpc_ctx
.
splited_var_names
.
size
();
ctx
++
)
{
for
(
int
part
=
0
;
part
<
multi_parts
;
part
++
)
{
auto
out_idx
=
ctx
*
multi_parts
+
part
;
auto
rows_idx
=
outs_rows_idx
[
out_idx
];
auto
dims
=
send_slr
.
GetCompleteDims
();
dims
[
0
]
=
rows_idx
.
size
();
outs
[
i
]
->
mutable_rows
()
->
clear
();
outs
[
i
]
->
mutable_value
()
->
mutable_data
<
T
>
(
dims
,
send_slr
.
place
());
outs
[
out_idx
]
->
set_height
(
rpc_ctx
.
height_sections
[
ctx
]);
outs
[
out_idx
]
->
mutable_rows
()
->
clear
();
outs
[
out_idx
]
->
mutable_value
()
->
mutable_data
<
T
>
(
dims
,
send_slr
.
place
());
if
(
rows_idx
.
size
()
>
0
)
{
for
(
auto
idx
:
rows_idx
)
{
outs
[
i
]
->
mutable_rows
()
->
push_back
(
idx
-
abs_sections
[
i
]);
outs
[
out_idx
]
->
mutable_rows
()
->
push_back
(
idx
-
abs_sections
[
ctx
]);
}
auto
dst
=
outs
[
i
]
->
mutable_value
()
->
mutable_data
<
T
>
(
place
);
auto
dst
=
outs
[
out_idx
]
->
mutable_value
()
->
mutable_data
<
T
>
(
place
);
for
(
size_t
j
=
0
;
j
<
rows_idx
.
size
();
j
++
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
memory
::
Copy
(
platform
::
CPUPlace
(),
dst
+
j
*
row_numel
,
platform
::
CPUPlace
(),
src
+
outs_dense_idx
[
i
][
j
]
*
row_numel
,
sizeof
(
T
)
*
row_numel
);
memory
::
Copy
(
platform
::
CPUPlace
(),
dst
+
j
*
row_numel
,
platform
::
CPUPlace
(),
src
+
outs_dense_idx
[
out_idx
][
j
]
*
row_numel
,
sizeof
(
T
)
*
row_numel
);
}
else
{
PADDLE_THROW
(
"do not support GPU now"
);
/*
#ifdef PADDLE_WITH_CUDA
auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(),
src + outs_dense_idx[i][j] * row_numel,
sizeof(T) * row_numel, stream);
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
*/
}
}
}
PADDLE_ENFORCE_EQ
(
rows_idx
.
size
(),
outs
[
i
]
->
rows
().
size
(),
PADDLE_ENFORCE_EQ
(
rows_idx
.
size
(),
outs
[
out_idx
]
->
rows
().
size
(),
"rows should has the same size with tensor dim 0"
);
}
}
else
{
PADDLE_THROW
(
"unsupported var type to send!"
);
}
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
auto
&
send_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
auto
&
endpoint
=
rpc_ctx
.
epmap
[
i
];
if
(
NeedSend
(
*
local_scope
.
get
(),
send_var_name
))
{
VLOG
(
3
)
<<
"sending "
<<
send_var_name
<<
" to "
<<
endpoint
;
for
(
size_t
i
=
0
;
i
<
table_pairs
.
size
();
i
++
)
{
auto
&
send_var_name
=
table_pairs
[
i
].
second
;
auto
&
endpoint
=
table_pairs
[
i
].
first
;
auto
need_send
=
NeedSend
(
*
local_scope
.
get
(),
send_var_name
);
VLOG
(
4
)
<<
"send var name: "
<<
send_var_name
<<
"send var endpoint: "
<<
endpoint
<<
"need send: "
<<
need_send
;
if
(
need_send
)
{
VLOG
(
4
)
<<
"sending "
<<
send_var_name
<<
" to "
<<
endpoint
;
rets
.
push_back
(
rpc_client
->
AsyncSendVar
(
endpoint
,
cpu_ctx
,
*
local_scope
.
get
(),
send_var_name
));
VLOG
(
4
)
<<
"send var "
<<
send_var_name
<<
" async handle done"
;
}
else
{
VLOG
(
3
)
<<
"don't send non-initialized variable: "
VLOG
(
4
)
<<
"don't send non-initialized variable: "
<<
rpc_ctx
.
splited_var_names
[
i
];
}
}
}
else
{
PADDLE_THROW
(
"unsupported var type to send!"
);
}
VLOG
(
4
)
<<
"Prepare to send var "
<<
rpc_ctx
.
var_name
;
if
(
sync
)
{
for
(
auto
&
handle
:
rets
)
{
VLOG
(
4
)
<<
"Wait send var to pserver handle: "
<<
handle
;
PADDLE_ENFORCE
(
handle
->
Wait
(),
"internal error in RPCClient"
);
}
}
...
...
paddle/fluid/operators/distributed/parameter_send.h
浏览文件 @
728ec1b4
...
...
@@ -27,7 +27,7 @@ namespace distributed {
template
<
typename
T
>
struct
ParameterSend
{
void
operator
()(
const
RpcContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
,
bool
sync
);
bool
sync
,
int
multi_parts
);
};
};
// namespace distributed
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
728ec1b4
...
...
@@ -26,6 +26,7 @@
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -60,13 +61,26 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"
);
}
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
varname
))
{
std
::
string
run_varname
=
varname
;
string
::
Piece
part_piece
(
"@PIECE"
);
string
::
Piece
var_name_piece
=
string
::
Piece
(
varname
);
if
(
string
::
Contains
(
var_name_piece
,
part_piece
))
{
auto
varname_splits
=
paddle
::
string
::
Split
(
varname
,
'@'
);
PADDLE_ENFORCE_EQ
(
varname_splits
.
size
(),
3
);
run_varname
=
varname_splits
[
0
];
scope
->
Rename
(
varname
,
run_varname
);
}
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
auto
&
grad_slr
=
scope
->
FindVar
(
varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
varname
,
scope
->
FindVar
(
run_
varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_
varname
,
grad_slr
.
rows
());
}
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
run_
varname
].
get
(),
scope
);
return
true
;
}
else
{
// sync
...
...
@@ -116,9 +130,48 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
VLOG
(
1
)
<<
"Table name empty? "
<<
table_name
.
empty
();
VLOG
(
1
)
<<
"AsyncSparseParamUpdateRecorder "
<<
varname
<<
" exist "
<<
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
);
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
varname
,
trainer_id
,
&
updated_rows
);
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
updated_rows
)
{
sstream
<<
row_id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"updated_rows size: "
<<
updated_rows
.
size
()
<<
" "
<<
sstream
.
str
();
}
auto
&
origin_tensor
=
scope_
->
FindVar
(
varname
)
->
Get
<
framework
::
LoDTensor
>
();
auto
*
origin_tensor_data
=
origin_tensor
.
data
<
float
>
();
auto
&
dims
=
origin_tensor
.
dims
();
*
outvar
=
scope
->
Var
();
auto
*
out_slr
=
(
*
outvar
)
->
GetMutable
<
framework
::
SelectedRows
>
();
out_slr
->
set_rows
(
updated_rows
);
out_slr
->
set_height
(
dims
[
0
]);
auto
out_dims
=
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
updated_rows
.
size
()),
dims
[
1
]});
auto
*
data
=
out_slr
->
mutable_value
()
->
mutable_data
<
float
>
(
out_dims
,
origin_tensor
.
place
());
auto
width
=
dims
[
1
];
for
(
auto
i
=
0
;
i
<
updated_rows
.
size
();
++
i
)
{
PADDLE_ENFORCE_LT
(
updated_rows
[
i
],
dims
[
0
]);
memcpy
(
data
+
i
*
width
,
origin_tensor_data
+
updated_rows
[
i
]
*
width
,
sizeof
(
float
)
*
width
);
}
}
else
{
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
}
return
true
;
}
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
728ec1b4
...
...
@@ -47,8 +47,12 @@ class SendOp : public framework::OperatorBase {
auto
height_sections
=
Attr
<
std
::
vector
<
int64_t
>>
(
"sections"
);
if
(
send_varnames
.
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
1
,
""
);
if
(
ins
.
size
()
>
1
)
{
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
,
send_varnames
,
scope
);
}
else
{
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
[
0
],
scope
);
}
}
else
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/fluid/pybind/communicator_py.cc
浏览文件 @
728ec1b4
...
...
@@ -15,8 +15,10 @@ limitations under the License. */
#include "paddle/fluid/pybind/communicator_py.h"
#include <Python.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "pybind11/pybind11.h"
...
...
@@ -27,6 +29,7 @@ namespace py = pybind11;
using
paddle
::
framework
::
ProgramDesc
;
using
paddle
::
operators
::
distributed
::
Communicator
;
using
paddle
::
operators
::
distributed
::
AsyncCommunicator
;
using
paddle
::
operators
::
distributed
::
GeoSgdCommunicator
;
using
paddle
::
framework
::
Scope
;
namespace
paddle
{
...
...
@@ -37,9 +40,20 @@ void BindCommunicator(py::module* m) {
py
::
class_
<
Communicator
,
std
::
shared_ptr
<
Communicator
>>
(
*
m
,
"DistCommunicator"
)
.
def
(
py
::
init
([](
const
ProgramDesc
&
program
,
Scope
*
param_scope
)
{
VLOG
(
0
)
<<
"using communicator"
;
Communicator
::
InitInstance
<
AsyncCommunicator
>
(
program
,
param_scope
);
return
Communicator
::
GetInstantcePtr
();
}))
.
def
(
py
::
init
([](
const
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
int
&
trainers
,
int
&
geo_need_push_nums
)
{
VLOG
(
0
)
<<
"using geo sgd communicator"
;
Communicator
::
InitInstance
<
GeoSgdCommunicator
>
(
program
,
training_scope
,
vars_info
,
trainers
,
geo_need_push_nums
);
return
Communicator
::
GetInstantcePtr
();
}))
.
def
(
"stop"
,
&
Communicator
::
Stop
)
.
def
(
"start"
,
&
Communicator
::
Start
)
.
def
(
"is_running"
,
&
Communicator
::
IsRunning
);
...
...
paddle/fluid/pybind/communicator_py.h
浏览文件 @
728ec1b4
...
...
@@ -16,7 +16,11 @@ limitations under the License. */
#include <Python.h>
#include "pybind11/chrono.h"
#include "pybind11/complex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace
paddle
{
namespace
pybind
{
...
...
python/paddle/fluid/__init__.py
浏览文件 @
728ec1b4
...
...
@@ -195,6 +195,7 @@ def __bootstrap__():
read_env_flags
.
append
(
'communicator_min_send_grad_num_before_recv'
)
read_env_flags
.
append
(
'communicator_thread_pool_size'
)
read_env_flags
.
append
(
'communicator_max_merge_var_num'
)
read_env_flags
.
append
(
'communicator_merge_sparse_bucket'
)
read_env_flags
.
append
(
'communicator_fake_rpc'
)
read_env_flags
.
append
(
'communicator_send_wait_times'
)
read_env_flags
.
append
(
'communicator_merge_sparse_grad'
)
...
...
python/paddle/fluid/communicator.py
浏览文件 @
728ec1b4
...
...
@@ -13,6 +13,10 @@
# limitations under the License.
from
.executor
import
global_scope
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
"""
from
.
import
core
from
.framework
import
Program
...
...
@@ -20,7 +24,11 @@ __all__ = ['Communicator']
class
Communicator
(
object
):
def
__init__
(
self
,
program
):
def
__init__
(
self
,
program
,
vars_info
=
None
,
trainers
=
None
,
geo_sgd_need_push_nums
=
None
):
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
...
...
@@ -47,7 +55,15 @@ class Communicator(object):
for
op
in
program
.
block
(
0
).
ops
:
if
op
.
type
==
"recv"
:
op
.
_set_attr
(
'do_not_run'
,
True
)
self
.
communicator_
=
core
.
DistCommunicator
(
program
.
desc
,
global_scope
())
# Todo: Add check
if
vars_info
and
trainers
and
geo_sgd_need_push_nums
:
# for geo sgd
self
.
communicator_
=
core
.
DistCommunicator
(
program
.
desc
,
global_scope
(),
vars_info
,
trainers
,
geo_sgd_need_push_nums
)
else
:
self
.
communicator_
=
core
.
DistCommunicator
(
program
.
desc
,
global_scope
())
def
start
(
self
):
"""
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
浏览文件 @
728ec1b4
...
...
@@ -13,7 +13,9 @@
# limitations under the License.
import
os
import
warnings
"""
Convert the fluid program to distributed data-parallelism programs.
"""
import
paddle.fluid.io
as
io
from
paddle.fluid.communicator
import
Communicator
from
paddle.fluid.framework
import
default_main_program
...
...
@@ -24,6 +26,7 @@ from paddle.fluid.executor import Executor
from
paddle.fluid.parallel_executor
import
ParallelExecutor
from
paddle.fluid.optimizer
import
Optimizer
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspiler
as
OriginTranspiler
from
paddle.fluid.transpiler.geo_sgd_transpiler
import
GeoSgdTranspiler
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
from
paddle.fluid.incubate.fleet.base.fleet_base
import
DistributedOptimizer
...
...
@@ -64,6 +67,12 @@ class DistributedTranspiler(Fleet):
wait_server_ready
(
fleet
.
server_endpoints
(
to_string
=
False
))
if
not
self
.
_transpile_config
.
sync_mode
:
if
self
.
_transpile_config
.
geo_sgd_mode
:
self
.
_communicator
=
Communicator
(
self
.
main_program
,
self
.
vars_info
,
fleet
.
worker_num
(),
self
.
_transpile_config
.
geo_sgd_need_push_nums
)
else
:
self
.
_communicator
=
Communicator
(
self
.
main_program
)
if
not
self
.
_communicator
.
is_running
():
...
...
@@ -124,7 +133,6 @@ class DistributedTranspiler(Fleet):
):
self
.
_communicator
.
stop
()
self
.
_executor
.
close
()
if
isinstance
(
self
.
_role_maker
,
MPISymetricRoleMaker
):
self
.
_role_maker
.
_finalize
()
...
...
@@ -239,6 +247,9 @@ class DistributedTranspiler(Fleet):
self
.
_origin_program
=
default_main_program
().
clone
(
for_test
=
False
)
self
.
_transpile_config
=
config
if
config
.
geo_sgd_mode
:
self
.
_transpiler
=
GeoSgdTranspiler
(
config
)
else
:
self
.
_transpiler
=
OriginTranspiler
(
config
)
if
self
.
is_worker
():
...
...
@@ -254,6 +265,9 @@ class DistributedTranspiler(Fleet):
self
.
main_program
=
self
.
_transpiler
.
get_trainer_program
(
wait_port
=
config
.
wait_port
)
self
.
startup_program
=
default_startup_program
()
if
self
.
_transpile_config
.
geo_sgd_mode
:
self
.
vars_info
=
self
.
_transpiler
.
_get_vars_info
()
self
.
startup_program
=
self
.
_transpiler
.
trainer_startup_program
else
:
self
.
_transpiler
.
transpile
(
trainer_id
=
fleet
.
worker_index
(),
...
...
@@ -262,7 +276,8 @@ class DistributedTranspiler(Fleet):
sync_mode
=
config
.
sync_mode
,
current_endpoint
=
self
.
server_endpoints
()[
self
.
server_index
()])
self
.
main_program
,
self
.
startup_program
=
\
self
.
_transpiler
.
get_pserver_programs
(
self
.
server_endpoints
()[
self
.
server_index
()])
self
.
_transpiler
.
get_pserver_programs
(
self
.
server_endpoints
()[
self
.
server_index
()])
fleet
=
DistributedTranspiler
()
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
728ec1b4
...
...
@@ -24,6 +24,7 @@ if(NOT WITH_DISTRIBUTE)
LIST
(
REMOVE_ITEM TEST_OPS test_nce_remote_table_op
)
LIST
(
REMOVE_ITEM TEST_OPS test_hsigmoid_remote_table_op
)
LIST
(
REMOVE_ITEM TEST_OPS test_dist_fleet_ctr
)
LIST
(
REMOVE_ITEM TEST_OPS test_dist_fleet_geo
)
endif
(
NOT WITH_DISTRIBUTE
)
...
...
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
浏览文件 @
728ec1b4
...
...
@@ -13,7 +13,9 @@
# limitations under the License.
from
__future__
import
print_function
"""
high level unit test for distribute fleet.
"""
import
argparse
import
os
import
pickle
...
...
@@ -29,6 +31,7 @@ from contextlib import closing
import
six
import
unittest
import
numpy
as
np
import
tempfile
import
paddle.fluid
as
fluid
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
...
...
@@ -40,6 +43,12 @@ LEARNING_RATE = 0.01
class
FleetDistRunnerBase
(
object
):
"""
run_pserver,run_trainer : after init role, using transpiler split program
net : implment by child class, the network of model
do training : exe run program
"""
def
run_pserver
(
self
,
args
):
if
args
.
role
.
upper
()
!=
"PSERVER"
:
raise
ValueError
(
"args role must be PSERVER"
)
...
...
@@ -54,6 +63,8 @@ class FleetDistRunnerBase(object):
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
args
.
sync_mode
strategy
.
geo_sgd_mode
=
args
.
geo_sgd_mode
strategy
.
geo_sgd_need_push_nums
=
args
.
geo_sgd_need_push_nums
avg_cost
=
self
.
net
()
...
...
@@ -78,14 +89,14 @@ class FleetDistRunnerBase(object):
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
args
.
sync_mode
strategy
.
geo_sgd_mode
=
args
.
geo_sgd_mode
strategy
.
geo_sgd_need_push_nums
=
args
.
geo_sgd_need_push_nums
avg_cost
=
self
.
net
()
optimizer
=
fluid
.
optimizer
.
SGD
(
LEARNING_RATE
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
self
.
do_training
(
fleet
)
out
=
self
.
do_training
(
fleet
)
def
net
(
self
,
batch_size
=
4
,
lr
=
0.01
):
...
...
@@ -98,6 +109,11 @@ class FleetDistRunnerBase(object):
class
TestFleetBase
(
unittest
.
TestCase
):
"""
start_pserver,start_trainer : add start cmd to test
run_cluster : using multi process to test distribute program
"""
def
_setup_config
(
self
):
raise
NotImplementedError
(
"tests should have _setup_config implemented"
)
...
...
@@ -109,6 +125,8 @@ class TestFleetBase(unittest.TestCase):
self
.
_ps_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
self
.
_find_free_port
(),
self
.
_find_free_port
())
self
.
_python_interp
=
sys
.
executable
self
.
_geo_sgd
=
False
self
.
_geo_sgd_need_push_nums
=
5
self
.
_setup_config
()
def
_find_free_port
(
self
):
...
...
@@ -127,8 +145,8 @@ class TestFleetBase(unittest.TestCase):
def
_start_pserver
(
self
,
cmd
,
required_envs
):
ps0_cmd
,
ps1_cmd
=
cmd
.
format
(
0
),
cmd
.
format
(
1
)
ps0_pipe
=
open
(
"/tmp
/ps0_err.log"
,
"wb+"
)
ps1_pipe
=
open
(
"/tmp
/ps1_err.log"
,
"wb+"
)
ps0_pipe
=
open
(
tempfile
.
gettempdir
()
+
"
/ps0_err.log"
,
"wb+"
)
ps1_pipe
=
open
(
tempfile
.
gettempdir
()
+
"
/ps1_err.log"
,
"wb+"
)
ps0_proc
=
subprocess
.
Popen
(
ps0_cmd
.
strip
().
split
(
" "
),
...
...
@@ -140,14 +158,13 @@ class TestFleetBase(unittest.TestCase):
stdout
=
subprocess
.
PIPE
,
stderr
=
ps1_pipe
,
env
=
required_envs
)
return
ps0_proc
,
ps1_proc
,
ps0_pipe
,
ps1_pipe
def
_start_trainer
(
self
,
cmd
,
required_envs
):
tr0_cmd
,
tr1_cmd
=
cmd
.
format
(
0
),
cmd
.
format
(
1
)
tr0_pipe
=
open
(
"/tmp
/tr0_err.log"
,
"wb+"
)
tr1_pipe
=
open
(
"/tmp
/tr1_err.log"
,
"wb+"
)
tr0_pipe
=
open
(
tempfile
.
gettempdir
()
+
"
/tr0_err.log"
,
"wb+"
)
tr1_pipe
=
open
(
tempfile
.
gettempdir
()
+
"
/tr1_err.log"
,
"wb+"
)
tr0_proc
=
subprocess
.
Popen
(
tr0_cmd
.
strip
().
split
(
" "
),
...
...
@@ -164,18 +181,29 @@ class TestFleetBase(unittest.TestCase):
def
_run_cluster
(
self
,
model
,
envs
):
env
=
{
'CPU_NUM'
:
'1'
}
python_path
=
self
.
_python_interp
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
envs
[
'COVERAGE_FILE'
]
=
os
.
getenv
(
'COVERAGE_FILE'
,
''
)
python_path
+=
" -m coverage run --branch -p"
env
.
update
(
envs
)
tr_cmd
=
"{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}"
.
format
(
self
.
_python_interp
,
model
,
self
.
_ps_endpoints
,
self
.
_trainers
)
python_path
,
model
,
self
.
_ps_endpoints
,
self
.
_trainers
)
ps_cmd
=
"{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}"
.
format
(
self
.
_python_interp
,
model
,
self
.
_ps_endpoints
,
self
.
_trainers
)
python_path
,
model
,
self
.
_ps_endpoints
,
self
.
_trainers
)
if
self
.
_sync_mode
:
tr_cmd
+=
" --sync_mode"
ps_cmd
+=
" --sync_mode"
if
self
.
_geo_sgd
:
tr_cmd
+=
" --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}"
.
format
(
self
.
_geo_sgd
,
self
.
_geo_sgd_need_push_nums
)
ps_cmd
+=
" --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}"
.
format
(
self
.
_geo_sgd
,
self
.
_geo_sgd_need_push_nums
)
# Run dist train to compare with local results
ps0
,
ps1
,
ps0_pipe
,
ps1_pipe
=
self
.
_start_pserver
(
ps_cmd
,
env
)
tr0
,
tr1
,
tr0_pipe
,
tr1_pipe
=
self
.
_start_trainer
(
tr_cmd
,
env
)
...
...
@@ -259,7 +287,10 @@ def runtime_main(test_class):
parser
.
add_argument
(
'--current_id'
,
type
=
int
,
required
=
False
,
default
=
0
)
parser
.
add_argument
(
'--trainers'
,
type
=
int
,
required
=
False
,
default
=
1
)
parser
.
add_argument
(
'--sync_mode'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--geo_sgd_mode'
,
type
=
bool
,
required
=
False
,
default
=
False
)
parser
.
add_argument
(
'--geo_sgd_need_push_nums'
,
type
=
int
,
required
=
False
,
default
=
2
)
args
=
parser
.
parse_args
()
model
=
test_class
()
...
...
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py
0 → 100644
浏览文件 @
728ec1b4
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler
import
fleet
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
from
test_dist_fleet_base
import
TestFleetBase
from
dist_simnet_bow
import
train_network
def
skip_ci
(
func
):
on_ci
=
bool
(
int
(
os
.
environ
.
get
(
"SKIP_UNSTABLE_CI"
,
'0'
)))
def
__func__
(
*
args
,
**
kwargs
):
if
on_ci
:
return
return
func
(
*
args
,
**
kwargs
)
return
__func__
class
TestDistGeoCtr_2x2
(
TestFleetBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_geo_sgd
=
True
self
.
_geo_sgd_need_push_nums
=
5
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
tr0_losses
,
tr1_losses
=
self
.
_run_cluster
(
model_file
,
required_envs
)
def
test_dist_train
(
self
):
self
.
check_with_place
(
"dist_fleet_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
class
TestGeoSgdTranspiler
(
unittest
.
TestCase
):
def
test_pserver
(
self
):
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
0
,
role
=
role_maker
.
Role
.
SERVER
,
worker_num
=
2
,
server_endpoints
=
[
"127.0.0.1:36011"
,
"127.0.0.1:36012"
])
fleet
.
init
(
role
)
batch_size
=
128
is_sparse
=
True
is_distribute
=
False
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
False
strategy
.
geo_sgd_mode
=
True
strategy
.
geo_sgd_need_push_nums
=
5
avg_cost
,
_
,
_
=
train_network
(
batch_size
,
is_distribute
,
is_sparse
)
optimizer
=
fluid
.
optimizer
.
SGD
(
0.1
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
pserver_startup_program
=
fleet
.
startup_program
pserver_mian_program
=
fleet
.
main_program
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
728ec1b4
...
...
@@ -180,6 +180,10 @@ class DistributeTranspilerConfig(object):
_runtime_split_send_recv
=
False
_sync_mode
=
True
# Geo-sgd algorithm
geo_sgd_mode
=
False
geo_sgd_need_push_nums
=
100
nccl_comm_num
=
1
#The picture here illustrates the principle:
#https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
...
...
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
0 → 100644
浏览文件 @
728ec1b4
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
"""
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. create delta variable in global scope which used to send
3. add send op to send sparse ids to communicator
Steps to transpile pserver:
1. create new program for parameter server.
2. create params variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append sum ops that should run on current server instance.
5. add listen_and_serv op
"""
import
sys
import
collections
import
six
import
numpy
as
np
from
.ps_dispatcher
import
RoundRobin
,
PSDispatcher
from
..
import
core
,
framework
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
Block
,
Parameter
from
.details
import
wait_server_ready
,
VarsDistributed
from
.details
import
delete_ops
from
..distribute_lookup_table
import
find_distributed_lookup_table
from
.distribute_transpiler
import
DistributeTranspiler
,
DistributeTranspilerConfig
,
slice_variable
,
same_or_split_var
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
class
GeoSgdTranspiler
(
DistributeTranspiler
):
def
__init__
(
self
,
config
=
None
):
if
config
is
not
None
:
self
.
config
=
config
else
:
self
.
config
=
DistributeTranspilerConfig
()
if
self
.
config
.
split_method
is
None
:
self
.
config
.
split_method
=
RoundRobin
assert
(
self
.
config
.
min_block_size
>=
8192
)
assert
(
self
.
config
.
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
def
transpile
(
self
,
trainer_id
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
sync_mode
=
False
,
startup_program
=
None
,
current_endpoint
=
"127.0.0.1:6174"
):
if
program
is
None
:
program
=
default_main_program
()
if
startup_program
is
None
:
startup_program
=
default_startup_program
()
self
.
origin_program
=
program
self
.
startup_program
=
startup_program
self
.
origin_startup_program
=
self
.
startup_program
.
clone
()
self
.
trainer_num
=
trainers
# geo-sgd only supply async-mode
self
.
sync_mode
=
False
self
.
trainer_id
=
trainer_id
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
vars_overview
=
VarsDistributed
()
self
.
optimize_ops
,
self
.
params_grads
=
self
.
_get_optimize_pass
()
ps_dispatcher
=
self
.
config
.
split_method
(
self
.
pserver_endpoints
)
self
.
param_name_to_grad_name
=
dict
()
self
.
grad_name_to_param_name
=
dict
()
for
param_var
,
grad_var
in
self
.
params_grads
:
self
.
param_name_to_grad_name
[
param_var
.
name
]
=
grad_var
.
name
self
.
grad_name_to_param_name
[
grad_var
.
name
]
=
param_var
.
name
# distribute lookup table
self
.
table_name
=
find_distributed_lookup_table
(
self
.
origin_program
)
self
.
has_distributed_lookup_table
=
self
.
table_name
!=
None
self
.
origin_program
.
_distributed_lookup_table
=
self
.
table_name
if
self
.
table_name
else
None
# add distributed attrs to program
self
.
origin_program
.
_is_distributed
=
True
self
.
origin_program
.
_endpoints
=
self
.
pserver_endpoints
self
.
origin_program
.
_ps_endpoint
=
current_endpoint
self
.
origin_program
.
_is_chief
=
self
.
trainer_id
==
0
# program info send to geo-sgd communicator
self
.
vars_info
=
collections
.
OrderedDict
()
self
.
split_to_origin_mapping
=
collections
.
OrderedDict
()
self
.
delta_vars_list
=
[]
self
.
sparse_var_list
=
[]
self
.
sparse_var_splited_list
=
[]
# split and create vars, then put splited vars in dicts for later use.
# step 1. split and create vars, then put splited vars in dicts for later use.
self
.
_init_splited_vars
()
# step 3. create send recv var (param after optimize)
send_vars
=
[]
ps_dispatcher
.
reset
()
param_var_mapping_items
=
list
(
six
.
iteritems
(
self
.
param_var_mapping
))
# send_vars is the parameter which splited by communicator and send to pserver,not the origin parameter
for
_
,
splited_vars
in
param_var_mapping_items
:
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
recv_vars
=
send_vars
ps_dispatcher
.
reset
()
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_opt_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
distributed_var
=
self
.
vars_overview
.
get_distributed_var_by_slice
(
recv_vars
[
i
].
name
)
distributed_var
.
endpoint
=
ep
origin_name
=
self
.
split_to_origin_mapping
[
recv_vars
[
i
].
name
]
self
.
vars_info
[
origin_name
][
"epmap"
].
append
(
ep
)
self
.
origin_program
.
_parameters_on_pservers
=
self
.
vars_overview
# send sparse id to communicator
self
.
sparse_var
=
[]
self
.
sparse_tables
=
[]
for
op
in
self
.
origin_program
.
global_block
().
ops
:
if
op
.
type
==
"lookup_table"
:
op
.
_set_attr
(
'remote_prefetch'
,
False
)
for
input_var_name
,
sparse_var_name
in
zip
(
op
.
input
(
"Ids"
),
op
.
input
(
"W"
)):
if
sparse_var_name
in
self
.
sparse_var_list
:
input_var
=
program
.
global_block
().
var
(
input_var_name
)
self
.
sparse_var
.
append
(
input_var
)
self
.
sparse_tables
.
append
(
sparse_var_name
)
# batch training loop end flag
dummy_output
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
program
.
global_block
().
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
self
.
sparse_var
},
outputs
=
{
"Out"
:
dummy_output
},
attrs
=
{
"send_varnames"
:
self
.
sparse_tables
})
# add param_init flag in trainer startup program
self
.
trainer_startup_program
=
self
.
_get_trainer_startup_program
(
recv_vars
=
recv_vars
,
eplist
=
eplist
)
for
delta_var
in
self
.
delta_vars_list
:
self
.
trainer_startup_program
.
global_block
().
create_var
(
name
=
delta_var
.
name
,
persistable
=
delta_var
.
persistable
,
dtype
=
delta_var
.
dtype
,
type
=
delta_var
.
type
,
shape
=
delta_var
.
shape
)
dummy_output
=
self
.
trainer_startup_program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
param_init
=
self
.
trainer_startup_program
.
global_block
().
create_var
(
name
=
"param_init"
)
self
.
trainer_startup_program
.
global_block
().
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
[
param_init
]},
outputs
=
{
"Out"
:
dummy_output
},
attrs
=
{
"send_varnames"
:
[
param_init
.
name
]})
def
_get_vars_info
(
self
):
return
self
.
vars_info
def
get_trainer_program
(
self
,
wait_port
=
True
):
# if wait_port:
# wait_server_ready(self.pserver_endpoints)
return
self
.
origin_program
def
get_pserver_programs
(
self
,
endpoint
):
pserver_prog
=
self
.
get_pserver_program
(
endpoint
)
self
.
param_grad_ep_mapping
=
self
.
param_opt_ep_mapping
pserver_startup
=
self
.
get_startup_program
(
endpoint
,
pserver_program
=
pserver_prog
)
return
pserver_prog
,
pserver_startup
def
get_pserver_program
(
self
,
endpoint
):
# step1
pserver_program
=
Program
()
pserver_program
.
random_seed
=
self
.
origin_program
.
random_seed
pserver_program
.
_copy_dist_param_info_from
(
self
.
origin_program
)
# step2: Create vars to receive vars at parameter servers.
recv_inputs
=
[]
for
v
in
self
.
param_opt_ep_mapping
[
endpoint
][
"params"
]:
self
.
_clone_var
(
pserver_program
.
global_block
(),
v
)
optimize_block
=
[]
param_to_block_id
=
[]
sparse_grad_to_param
=
[]
# append op to the current block
pre_block_idx
=
pserver_program
.
num_blocks
-
1
for
var
in
self
.
param_opt_ep_mapping
[
endpoint
][
"params"
]:
per_opt_block
=
pserver_program
.
_create_block
(
pre_block_idx
)
optimize_block
.
append
(
per_opt_block
)
var_name
=
var
.
name
pserver_block
=
per_opt_block
.
program
.
global_block
()
param
=
pserver_block
.
vars
[
var_name
]
delta_var_name
=
"%s.delta"
%
(
param
.
name
)
if
var
.
name
in
self
.
sparse_var_splited_list
:
delta_type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
sparse_grad_to_param
.
append
(
":"
.
join
(
[
delta_var_name
,
param
.
name
]))
else
:
delta_type
=
param
.
type
delta_var
=
pserver_block
.
create_var
(
name
=
delta_var_name
,
persistable
=
False
,
type
=
delta_type
,
dtype
=
param
.
dtype
,
shape
=
param
.
shape
)
per_opt_block
.
append_op
(
type
=
"sum"
,
inputs
=
{
"X"
:
[
param
,
delta_var
]},
outputs
=
{
"Out"
:
param
})
param_to_block_id
.
append
(
delta_var_name
+
":"
+
str
(
per_opt_block
.
idx
))
attrs
=
{
"optimize_blocks"
:
optimize_block
,
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainer_num
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
param_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
}
# step5 append the listen_and_serv op
pserver_program
.
global_block
().
append_op
(
type
=
"listen_and_serv"
,
inputs
=
{
'X'
:
recv_inputs
},
outputs
=
{},
attrs
=
attrs
)
pserver_program
.
_sync_with_cpp
()
# save pserver program to generate pserver side startup relatively.
self
.
pserver_program
=
pserver_program
return
pserver_program
def
_init_splited_vars
(
self
):
param_list
=
[]
grad_list
=
[]
param_grad_set
=
set
()
# step 1. create param_list
for
p
,
g
in
self
.
params_grads
:
if
type
(
p
)
==
Parameter
and
p
.
trainable
==
False
:
continue
if
p
.
name
not
in
param_grad_set
:
param_list
.
append
(
p
)
param_grad_set
.
add
(
p
.
name
)
if
g
.
name
not
in
param_grad_set
:
grad_list
.
append
(
g
)
param_grad_set
.
add
(
g
.
name
)
if
g
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
self
.
sparse_var_list
.
append
(
p
.
name
)
# step 2. Slice vars into numbers of piece with block_size
# when we slice var up into blocks, we will slice the var according to
# pserver services' count. A pserver may have two or more listening ports.
param_blocks
=
slice_variable
(
param_list
,
len
(
self
.
pserver_endpoints
),
self
.
config
.
min_block_size
)
# step 3. Create splited param from split blocks
# origin_param_name -> [splited_param_vars]
# Todo: update _create_vars_from_blocklist
self
.
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
self
.
origin_program
,
param_blocks
)
# step 4. Create mapping of endpoint -> split var to create pserver side program
self
.
param_opt_ep_mapping
=
collections
.
OrderedDict
()
[
self
.
param_opt_ep_mapping
.
update
({
ep
:
{
"params"
:
[],
}
})
for
ep
in
self
.
pserver_endpoints
]
# step 5. Create delta var of Geo-Sgd & record vars infomation
for
origin_name
,
splited_vars
in
self
.
param_var_mapping
.
items
():
origin_var
=
self
.
origin_program
.
global_block
().
var
(
origin_name
)
self
.
vars_info
[
origin_name
]
=
collections
.
OrderedDict
()
self
.
vars_info
[
origin_name
][
"var_names"
]
=
[]
vars_section
=
self
.
_get_splited_var_sections
(
splited_vars
)
self
.
vars_info
[
origin_name
][
"sections"
]
=
[
str
(
i
)
for
i
in
vars_section
]
self
.
vars_info
[
origin_name
][
"epmap"
]
=
[]
self
.
vars_info
[
origin_name
][
"is_sparse"
]
=
[]
# todo: add var shape(may be no need,because recv scope have)
if
origin_name
in
self
.
sparse_var_list
:
delta_type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
self
.
vars_info
[
origin_name
][
"is_sparse"
].
append
(
"True"
)
else
:
delta_type
=
origin_var
.
type
self
.
vars_info
[
origin_name
][
"is_sparse"
].
append
(
"False"
)
delta_var
=
self
.
origin_program
.
global_block
().
create_var
(
name
=
"."
.
join
([
origin_name
,
"delta"
]),
persistable
=
False
,
dtype
=
origin_var
.
dtype
,
type
=
delta_type
,
shape
=
origin_var
.
shape
)
self
.
delta_vars_list
.
append
(
delta_var
)
for
splited_var
in
splited_vars
:
is_slice
,
block_id
,
offset
=
self
.
_get_slice_var_info
(
splited_var
)
self
.
vars_overview
.
add_distributed_var
(
origin_var
=
origin_var
,
slice_var
=
splited_var
,
block_id
=
block_id
,
offset
=
offset
,
is_slice
=
is_slice
,
vtype
=
"Param"
)
self
.
split_to_origin_mapping
[
splited_var
.
name
]
=
origin_name
if
origin_name
in
self
.
sparse_var_list
:
self
.
sparse_var_splited_list
.
append
(
splited_var
.
name
)
self
.
vars_info
[
origin_name
][
"var_names"
].
append
(
splited_var
.
name
)
if
len
(
splited_vars
)
!=
1
:
self
.
origin_program
.
global_block
().
create_var
(
name
=
"."
.
join
([
splited_var
.
name
,
"delta"
]),
persistable
=
False
,
dtype
=
splited_var
.
dtype
,
type
=
delta_type
,
shape
=
splited_var
.
shape
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录