Commit 371f377b (unverified)
Authored on Feb 02, 2020 by xujiaqi01; committed via GitHub on Feb 02, 2020.

add GeneralRoleMaker (#22295)

* add GeneralRoleMaker which is for general usage
* test=develop

Parent: 269db0d1
Showing 17 changed files with 993 additions and 63 deletions (+993 −63).
Dockerfile                                                               +7   −0
paddle/fluid/framework/CMakeLists.txt                                    +1   −0
paddle/fluid/framework/data_set.cc                                       +2   −0
paddle/fluid/framework/dist_multi_trainer.cc                             +2   −2
paddle/fluid/framework/dist_multi_trainer_test.cc                        +56  −0
paddle/fluid/framework/fleet/gloo_wrapper.cc                             +30  −9
paddle/fluid/framework/fleet/gloo_wrapper.h                              +5   −2
paddle/fluid/framework/fleet/test_fleet.cc                               +1   −2
paddle/fluid/pybind/gloo_wrapper_py.cc                                   +3   −3
python/paddle/fluid/dataset.py                                           +8   −8
python/paddle/fluid/incubate/fleet/base/fleet_base.py                    +16  −0
python/paddle/fluid/incubate/fleet/base/role_maker.py                    +477 −21
python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py    +16  −6
python/paddle/fluid/incubate/fleet/utils/fleet_util.py                   +5   −5
python/paddle/fluid/tests/unittests/test_dataset.py                      +6   −2
python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py              +73  −3
python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py            +285 −0
Dockerfile

@@ -219,6 +219,13 @@ RUN wget -q https://launchpad.net/ubuntu/+archive/primary/+sourcefiles/binutils/
     cd binutils-2.27 && \
     ./configure && make -j && make install && cd .. && rm -rf binutils-2.27 binutils_2.27.orig.tar.gz

+RUN wget --no-check-certificate https://pslib.bj.bcebos.com/openmpi-1.4.5.tar.gz && tar -xzf openmpi-1.4.5.tar.gz && \
+    cd openmpi-1.4.5 && ./configure --prefix=/usr/local && make all -j8 && make install -j8 && \
+    export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH && export PATH=/usr/local/bin:$PATH && cd .. && \
+    rm -rf openmpi-1.4.5.tar.gz && pip --no-cache-dir install mpi4py && ln -fs /bin/bash /bin/sh && \
+    apt-get install libprotobuf-dev -y
+
+RUN pip --no-cache-dir install -U netifaces==0.10.9

 # Older versions of patchelf limited the size of the files being processed and were fixed in this pr.
 # https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa
 # So install a newer version here.
paddle/fluid/framework/CMakeLists.txt

@@ -214,6 +214,7 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
         graph build_strategy
         fast_threaded_ssa_graph_executor variable_helper)

+cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS executor)
 cc_library(prune SRCS prune.cc DEPS framework_proto boost)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
 cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
paddle/fluid/framework/data_set.cc

@@ -287,6 +287,7 @@ void DatasetImpl<T>::LocalShuffle() {
 template <typename T>
 void DatasetImpl<T>::GlobalShuffle(int thread_num) {
+#ifdef PADDLE_WITH_PSLIB
   VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
   platform::Timer timeline;
   timeline.Start();

@@ -379,6 +380,7 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
   timeline.Pause();
   VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
           << timeline.ElapsedSec() << " seconds";
+#endif
 }

 template <typename T>
paddle/fluid/framework/dist_multi_trainer.cc

@@ -41,8 +41,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
       need_dump_field_ = false;
     }
   }
-  mpi_rank_ = trainer_desc.mpi_rank() / 2;
-  mpi_size_ = trainer_desc.mpi_size() / 2;
+  mpi_rank_ = trainer_desc.mpi_rank();
+  mpi_size_ = trainer_desc.mpi_size();
   dump_file_num_ = trainer_desc.dump_file_num();
   const std::vector<paddle::framework::DataFeed *> readers =
       dataset->GetReaders();
paddle/fluid/framework/dist_multi_trainer_test.cc  (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>
#include <iostream>
#include <sstream>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/trainer.h"

#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

namespace paddle {
namespace framework {
TEST(DisMultiTrainerTest, test1) {
#ifdef _LINUX
  std::shared_ptr<DistMultiTrainer> tmp1 = std::make_shared<DistMultiTrainer>();
  TrainerDesc t;
  t.set_class_name("DistMultiTrainer");
  t.set_device_worker_name("DownpourWorker");
  t.set_thread_num(1);
  auto* m = t.mutable_downpour_param()->add_program_config();
  m->set_program_id("123");
  std::string str;
  str += "name: \"MultiSlotDataFeed\"\nbatch_size: 2\nmulti_slot_desc {\n";
  str += "slots {\nname: \"words\"\ntype: \"uint64\"\nis_dense: false\n";
  str += "is_used: true\n}\nslots {\nname: \"label\"\ntype: \"uint64\"\n";
  str += "is_dense: false\nis_used: true\n}\n}\n";
  std::shared_ptr<MultiSlotDataset> dataset =
      std::make_shared<MultiSlotDataset>();
  dataset->SetFileList(std::vector<std::string>());
  dataset->SetThreadNum(1);
  dataset->SetTrainerNum(1);
  dataset->SetDataFeedDesc(str);
  dataset->CreateReaders();
  tmp1->Initialize(t, dataset.get());
#endif
}
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/fleet/gloo_wrapper.cc

@@ -21,6 +21,7 @@ HdfsStore::HdfsStore(const std::string& path) {
   path_ = path;
   wait_sleep_ms_ = 3000;
   wait_timeout_ = std::chrono::seconds(999999999);
+  retry_times_ = 100;
 }

 void HdfsStore::set(const std::string& key, const std::vector<char>& data) {

@@ -33,10 +34,27 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
     paddle::framework::fs_remove(path);
   }
   int err_no = 0;
-  std::shared_ptr<FILE> fp = paddle::framework::fs_open_write(tmp, &err_no, "");
-  size_t write_count = fwrite_unlocked(data.data(), 1, data.size(), fp.get());
-  VLOG(3) << "HdfsStore::set write_count=" << write_count << " key " << key;
-  fp.reset();
+  for (int i = 1; i <= retry_times_; ++i) {
+    std::shared_ptr<FILE> fp =
+        paddle::framework::fs_open_write(tmp, &err_no, "");
+    if (err_no != 0) {
+      VLOG(0) << "fs_open_write failed, retry times " << i << " err no "
+              << err_no;
+      fp.reset();
+      sleep(wait_sleep_ms_ / 1000);
+      continue;
+    }
+    size_t write_count = fwrite_unlocked(data.data(), 1, data.size(), fp.get());
+    if (write_count != data.size()) {
+      VLOG(0) << "fwrite_unlocked failed, retry times " << i << " write_count "
+              << write_count << " data.size() " << data.size();
+      fp.reset();
+      sleep(2);
+      continue;
+    }
+    fp.reset();
+    break;
+  }
   paddle::framework::fs_mv(tmp, path);
 #endif
 }

@@ -131,7 +149,7 @@ void GlooWrapper::Init(int rank, int size, const std::string& path,
   }
   rank_ = rank;
   size_ = size;
-  std::string cmd = std::string("hadoop fs");
+  std::string cmd = std::string("${HADOOP_HOME}/bin/hadoop fs");
   cmd += " -D fs.default.name=" + fs_name;
   cmd += " -D hadoop.job.ugi=" + fs_ugi;
   paddle::framework::hdfs_set_command(cmd);

@@ -149,16 +167,19 @@ void GlooWrapper::Init(int rank, int size, const std::string& path,
   is_initialized_ = true;
 }

-template void GlooWrapper::AllReduce<int64_t>(
+template std::vector<int64_t> GlooWrapper::AllReduce<int64_t>(
     std::vector<int64_t>& sendbuf,  // NOLINT
-    std::vector<int64_t>& recvbuf,  // NOLINT
     const std::string& mode);
-template void GlooWrapper::AllReduce<double>(
+template std::vector<double> GlooWrapper::AllReduce<double>(
     std::vector<double>& sendbuf,  // NOLINT
-    std::vector<double>& recvbuf,  // NOLINT
     const std::string& mode);
+template std::vector<uint64_t> GlooWrapper::AllReduce<uint64_t>(
+    std::vector<uint64_t>& sendbuf,  // NOLINT
+    const std::string& mode);
 template std::vector<int64_t> GlooWrapper::AllGather<int64_t>(
     int64_t& input);  // NOLINT
+template std::vector<uint64_t> GlooWrapper::AllGather<uint64_t>(
+    uint64_t& input);  // NOLINT
 template std::vector<double> GlooWrapper::AllGather<double>(
     double& input);  // NOLINT
paddle/fluid/framework/fleet/gloo_wrapper.h

@@ -70,6 +70,7 @@ class HdfsStore {
   std::string path_;
   int wait_sleep_ms_;
   std::chrono::seconds wait_timeout_;
+  int retry_times_;
 };

 }  // namespace rendezvous

@@ -107,9 +108,10 @@ class GlooWrapper {
   }

   template <typename T>
-  void AllReduce(std::vector<T>& sendbuf, std::vector<T>& recvbuf,  // NOLINT
-                 const std::string& mode = "sum") {
+  std::vector<T> AllReduce(std::vector<T>& sendbuf,            // NOLINT
+                           const std::string& mode = "sum") {  // NOLINT
     CHECK_EQ(is_initialized_, true);
+    std::vector<T> recvbuf(sendbuf.size(), T());
     CHECK_EQ(sendbuf.size() == recvbuf.size(), true);
 #ifdef PADDLE_WITH_GLOO
     gloo::AllreduceOptions opts(context_);

@@ -133,6 +135,7 @@ class GlooWrapper {
     }
     gloo::allreduce(opts);
 #endif
+    return recvbuf;
   }

   template <typename T>
paddle/fluid/framework/fleet/test_fleet.cc

@@ -49,8 +49,7 @@ TEST(TEST_GLOO, store_1) {
   gw.Size();
   gw.Barrier();
   std::vector<double> input;
-  std::vector<double> output;
-  gw.AllReduce(input, output);
+  gw.AllReduce(input);
   int64_t t;
   gw.AllGather(t);
 #endif
paddle/fluid/pybind/gloo_wrapper_py.cc

@@ -37,12 +37,12 @@ void BindGlooWrapper(py::module* m) {
       .def("rank", &framework::GlooWrapper::Rank)
       .def("size", &framework::GlooWrapper::Size)
      .def("barrier", &framework::GlooWrapper::Barrier)
+      .def("all_reduce", &framework::GlooWrapper::AllReduce<uint64_t>)
+      .def("all_reduce", &framework::GlooWrapper::AllReduce<int64_t>)
+      .def("all_reduce", &framework::GlooWrapper::AllReduce<double>)
+      .def("all_gather", &framework::GlooWrapper::AllGather<uint64_t>)
       .def("all_gather", &framework::GlooWrapper::AllGather<int64_t>)
-      .def("Allreduce", &framework::GlooWrapper::AllReduce<int64_t>)
-      .def("Allreduce", &framework::GlooWrapper::AllReduce<double>)
       .def("all_gather", &framework::GlooWrapper::AllGather<double>);
 }  // end BindGlooWrapper
 }  // end namespace pybind
 }  // end namespace paddle
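
The bindings above are what role_maker.py later calls through fluid.core.Gloo. A minimal sketch of the Python side, assuming Gloo/HDFS support is compiled in and that the rank, endpoints, rendezvous path, HDFS name/ugi and interface name below are placeholders for a real two-process job (they are not taken from the commit):

    # Hedged example only; every concrete value here is a placeholder.
    import paddle.fluid as fluid

    gloo = fluid.core.Gloo()
    # init(rank, size, rendezvous_path, fs_name, fs_ugi, iface, prefix),
    # mirroring the call made in GeneralRoleMaker.generate_role() below.
    gloo.init(0, 2, "/tmp/gloo_demo/trainer", "hdfs://nameservice", "user,passwd", "eth0", "")
    gloo.barrier()
    summed = gloo.all_reduce([1, 2, 3], "sum")   # the rebound AllReduce now returns the reduced list
    gathered = gloo.all_gather(7)                # one value gathered from every rank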
python/paddle/fluid/dataset.py

@@ -526,7 +526,7 @@ class InMemoryDataset(DatasetBase):
         """
         trainer_num = 1
         if fleet is not None:
-            fleet._role_maker._barrier_worker()
+            fleet._role_maker.barrier_worker()
             trainer_num = fleet.worker_num()
         if self.fleet_send_batch_size is None:
             self.fleet_send_batch_size = 1024

@@ -537,14 +537,14 @@ class InMemoryDataset(DatasetBase):
         self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
         self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
         if fleet is not None:
-            fleet._role_maker._barrier_worker()
+            fleet._role_maker.barrier_worker()
         self.dataset.global_shuffle(thread_num)
         if fleet is not None:
-            fleet._role_maker._barrier_worker()
+            fleet._role_maker.barrier_worker()
         if self.merge_by_lineid:
             self.dataset.merge_by_lineid()
         if fleet is not None:
-            fleet._role_maker._barrier_worker()
+            fleet._role_maker.barrier_worker()

     def release_memory(self):
         """

@@ -599,8 +599,8 @@ class InMemoryDataset(DatasetBase):
         local_data_size = np.array([local_data_size])
         if fleet is not None:
             global_data_size = local_data_size * 0
-            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
-                                                        global_data_size)
+            fleet._role_maker.all_reduce_worker(local_data_size,
+                                                global_data_size)
             return global_data_size[0]
         return local_data_size[0]

@@ -637,8 +637,8 @@ class InMemoryDataset(DatasetBase):
         local_data_size = np.array([local_data_size])
         if fleet is not None:
             global_data_size = local_data_size * 0
-            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
-                                                        global_data_size)
+            fleet._role_maker.all_reduce_worker(local_data_size,
+                                                global_data_size)
             return global_data_size[0]
         return local_data_size[0]
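
The barriers and the all-reduce above run inside InMemoryDataset; calling code keeps using the dataset API unchanged. A hedged sketch of the calling side, with the file name and the slot variables as placeholders (the unit tests later in this diff exercise the same calls):

    # Hedged example; "demo_data.txt" and slots_vars are placeholders, and
    # fleet is assumed to be already initialized with a suitable role maker.
    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_filelist(["demo_data.txt"])
    dataset.set_use_var(slots_vars)                 # list of fluid variables (placeholder)
    dataset.load_into_memory()
    dataset.global_shuffle(fleet)                   # synchronizes via fleet._role_maker.barrier_worker()
    print(dataset.get_shuffle_data_size(fleet))     # aggregated via all_reduce_worker()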
python/paddle/fluid/incubate/fleet/base/fleet_base.py

@@ -202,6 +202,22 @@ class Fleet(object):
             self._role_maker.generate_role()
             self._is_initialized = True

+    def all_reduce_worker(self, input, output):
+        """
+        all reduce between workers, only support array of one dim.
+
+        Args:
+            input(list|numpy.array): array of one dim
+            output(list|numpy.array): array of one dim
+        """
+        self._role_maker.all_reduce_worker(input, output)
+
+    def barrier_worker(self):
+        """
+        barrier between workers
+        """
+        self._role_maker.barrier_worker()
+
     @abc.abstractmethod
     def init_worker(self):
         pass
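
A minimal sketch of the two Fleet helpers added above, following the same zeroed-buffer pattern dataset.py uses; it assumes fleet has already been initialized with a role maker that implements these calls (for example the GeneralRoleMaker below), and the counter value is made up:

    # Hedged example; assumes fleet.init(...) has already run on every worker.
    import numpy as np
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

    local_count = np.array([128])          # one-dim array holding this worker's statistic
    global_count = local_count * 0         # zeroed output buffer of the same shape
    fleet.all_reduce_worker(local_count, global_count)   # element-wise sum across workers
    fleet.barrier_worker()                 # keep workers in step before reading the result
    print("total:", global_count[0])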
python/paddle/fluid/incubate/fleet/base/role_maker.py

@@ -11,16 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Defination of Role Makers."""

 from __future__ import print_function
+import paddle.fluid as fluid
+import os
+import time

 __all__ = [
     'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
-    'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
+    'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker', 'GeneralRoleMaker'
 ]

-import os


 class Role:
     WORKER = 1

@@ -107,6 +109,43 @@ class RoleMakerBase(object):
            self._role, self._current_id, self._worker_endpoints,
            self._server_endpoints)

    def all_gather(self, input):
        """
        all gather between trainers and pservers

        Args:
            input(int|float): input value

        Returns:
            return a list of values
        """
        print("warning: RoleMakerBase does not have all gather.")
        return None

    def all_reduce_worker(self, input, output, mode="sum"):
        """
        all reduce between trainers if current role is TRAINER,
        only support array of one dim.

        Args:
            input(list/numpy.array): array of one dim
            output(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        """
        print("warning: RoleMakerBase does not have all reduce worker.")

    def barrier_worker(self):
        """
        barrier between trainers if current role is TRAINER
        """
        print("warning: RoleMakerBase does not have barrier worker.")

    def barrier_all(self):
        """
        barrier between trainers if current role is PSERVER
        """
        print("warning: RoleMakerBase does not have barrier all.")


class MPIRoleMaker(RoleMakerBase):
    """

@@ -115,6 +154,7 @@ class MPIRoleMaker(RoleMakerBase):
    def __init__(self):
        """Init."""
        super(MPIRoleMaker, self).__init__()
        from mpi4py import MPI
        self.MPI = MPI

@@ -124,16 +164,12 @@ class MPIRoleMaker(RoleMakerBase):
         self._ip = None

     def _get_rank(self):
-        """
-        return rank
-        """
+        """Return rank."""
         self._rank = self._comm.Get_rank()
         return self._rank

     def _get_size(self):
-        """
-        return size
-        """
+        """Return size."""
         self._size = self._comm.Get_size()
         return self._size

@@ -174,9 +210,7 @@ class MPIRoleMaker(RoleMakerBase):
         return self._ips

     def get_local_ip(self):
-        """
-        return get local ip
-        """
+        """Return get local ip."""
         import socket
         self._ip = socket.gethostbyname(socket.gethostname())
         return self._ip

@@ -196,16 +230,68 @@ class MPISymetricRoleMaker(MPIRoleMaker):
    def __init__(self):
        """Init."""
        super(MPISymetricRoleMaker, self).__init__()
        self._node_type = None
        self._proc_per_node = 2
        self._pserver_rand_port = 0

    def _check_role_generation(self):
        """Check whether role has been generated."""
        if not self._role_is_generated:
            raise NameError("generate_role() should be called first")
        return True

    def all_gather(self, input):
        """
        all gather between trainers and pservers

        Args:
            input(int|float): input value

        Returns:
            return a list of values
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._all_gather(input)

    def all_reduce_worker(self, input, output, mode="sum"):
        """
        all reduce between trainers if current role is TRAINER,
        only support array of one dim.

        Args:
            input(list/numpy.array): array of one dim
            output(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        """
        if not self._role_is_generated:
            self.generate_role()
        if not self.is_worker():
            print("warning: current role is not worker in all_reduce_worker")
            return
        self._all_reduce(input, output, mode)

    def barrier_worker(self):
        """
        barrier between trainers if current role is TRAINER
        """
        if not self._role_is_generated:
            self.generate_role()
        if self.is_worker():
            self._node_type_comm.barrier()
        else:
            print("warning: current role is not worker in barrier_worker")

    def barrier_all(self):
        """
        barrier between trainers if current role is PSERVER
        """
        if not self._role_is_generated:
            self.generate_role()
        self._comm.barrier()

    def is_first_worker(self):
        """
        return whether current process is the first worker assigned by role maker

@@ -215,6 +301,12 @@ class MPISymetricRoleMaker(MPIRoleMaker):
            return False

    def get_pserver_endpoints(self):
        """
        get pserver endpoints

        Returns:
            endpoints(list): pserver endpoints
        """
        if self._pserver_rand_port <= 0:
            import random
            random.seed(self._server_num())

@@ -285,6 +377,28 @@ class MPISymetricRoleMaker(MPIRoleMaker):
            self.generate_role()
        return self._get_size() / self._proc_per_node

    def _all_reduce(self, input, output, mode="sum"):
        """
        all reduce between trainers if current role is TRAINER,
        only support array of one dim.

        Args:
            input(list/numpy.array): array of one dim
            output(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        """
        if not self._role_is_generated:
            self.generate_role()
        if mode == "sum":
            mode = self.MPI.SUM
        elif mode == "max":
            mode = self.MPI.MAX
        elif mode == "min":
            mode = self.MPI.MIN
        else:
            raise ValueError("unknown mode: %s" % mode)
        self._node_type_comm.Allreduce(input, output, op=mode)

    def _barrier_worker(self):
        """
        barrier all workers in current distributed job

@@ -325,12 +439,18 @@
class PaddleCloudRoleMaker(RoleMakerBase):
    """
    role maker for paddle cloud,
    base class is RoleMakerBase
    """

    def __init__(self, is_collective=False):
        super(PaddleCloudRoleMaker, self).__init__()
        self._role_is_generated = False
        self._is_collective = is_collective

    def generate_role(self):
        """Generate role."""
        if not self._role_is_generated:
            if not self._is_collective:
                try:

@@ -419,17 +539,352 @@  (the GeneralRoleMaker class below is added in full by this commit)
        return self._trainers_num


class GeneralRoleMaker(RoleMakerBase):
    """
    This role maker is for general use, you can set os.environ to customize:
        PADDLE_PSERVERS_IP_PORT_LIST : all pservers' ip:port, seperated by ','
        PADDLE_TRAINER_ENDPOINTS : all trainers' ip:port, seperated by ','
        TRAINING_ROLE : TRAINER or PSERVER
        PADDLE_TRAINER_ID : current trainer id (only for trainer),
                            it is index in PADDLE_TRAINER_ENDPOINTS
        PADDLE_PSERVER_ID : current pserver id (only for pserver)
                            it is index in PADDLE_PSERVERS_IP_PORT_LIST
    """

    def __init__(self, **kwargs):
        super(RoleMakerBase, self).__init__()
        self._role_is_generated = False
        self._hdfs_name = kwargs.get("hdfs_name", "")
        self._hdfs_ugi = kwargs.get("hdfs_ugi", "")
        self._hdfs_path = kwargs.get("path", "")
        self._iface = self.__get_default_iface()
        # this environment variable can be empty
        self._prefix = os.getenv("SYS_JOB_ID", "")

    def generate_role(self):
        """
        generate role for general role maker
        """
        if not self._role_is_generated:
            eplist = os.environ["PADDLE_PSERVERS_IP_PORT_LIST"].split(",")
            training_role = os.environ["TRAINING_ROLE"]
            worker_endpoints = os.environ["PADDLE_TRAINER_ENDPOINTS"].split(",")
            trainers_num = len(worker_endpoints)
            if training_role not in ["TRAINER", "PSERVER"]:
                raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
            if training_role == "TRAINER":
                role = Role.WORKER
                current_id = int(os.environ["PADDLE_TRAINER_ID"])
                self._node_type = 1
                self._cur_endpoint = worker_endpoints[current_id]
                gloo = fluid.core.Gloo()
                gloo.init(current_id,
                          len(worker_endpoints),
                          self._hdfs_path.rstrip("/") + "/trainer",
                          self._hdfs_name, self._hdfs_ugi, self._iface,
                          self._prefix)
                self._node_type_comm = gloo
            elif training_role == "PSERVER":
                role = Role.SERVER
                if os.environ.get("PADDLE_PSERVER_ID") is not None:
                    current_id = int(os.environ["PADDLE_PSERVER_ID"])
                    cur_endpoint = eplist[current_id]
                else:
                    # this is for compatible with paddlecloud
                    cur_ip = os.environ["POD_IP"]
                    cur_port = os.environ["PADDLE_PORT"]
                    cur_endpoint = ":".join([cur_ip, cur_port])
                    current_id = eplist.index(cur_endpoint)
                self._node_type = 0
                self._cur_endpoint = cur_endpoint
                gloo = fluid.core.Gloo()
                gloo.init(current_id,
                          len(eplist),
                          self._hdfs_path.rstrip("/") + "/pserver",
                          self._hdfs_name, self._hdfs_ugi, self._iface,
                          self._prefix)
                self._node_type_comm = gloo

            gloo = fluid.core.Gloo()
            all_list = worker_endpoints + eplist
            gloo.init(all_list.index(self._cur_endpoint),
                      len(all_list),
                      self._hdfs_path.rstrip("/") + "/all",
                      self._hdfs_name, self._hdfs_ugi, self._iface,
                      self._prefix)
            self._all_comm = gloo
            self._trainers_num = trainers_num
            self._server_endpoints = eplist
            self._role = role
            self._current_id = current_id
            self._rank = all_list.index(self._cur_endpoint)
            self._size = len(all_list)
            self._worker_endpoints = worker_endpoints
            self._role_is_generated = True

    def all_gather(self, input):
        """
        all gather between trainers and pservers

        Args:
            input(int|float): input value

        Returns:
            return a list of values
        """
        return self._all_gather(input)

    def all_reduce_worker(self, input, output, mode="sum"):
        """
        all reduce between trainers if current role is TRAINER,
        only support array of one dim.

        Args:
            input(list/numpy.array): array of one dim
            output(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        """
        if not self.is_worker():
            return
        self._all_reduce(input, output, mode)

    def barrier_worker(self):
        """
        barrier between trainers if current role is TRAINER
        """
        self._barrier_worker()

    def barrier_all(self):
        """
        barrier between trainers if current role is PSERVER
        """
        self._barrier_all()

    def get_local_endpoint(self):
        """
        get local endpoint of current process
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._cur_endpoint

    def get_trainer_endpoints(self):
        """
        get endpoint of all trainers
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._worker_endpoints

    def get_pserver_endpoints(self):
        """
        get endpoint of all pservers
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._server_endpoints

    def is_worker(self):
        """
        whether current process is worker
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._role == Role.WORKER

    def is_server(self):
        """
        whether current process is server
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._role == Role.SERVER

    def is_first_worker(self):
        """
        whether current process is worker of rank 0
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        """
        get index of current worker
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._current_id

    def server_index(self):
        """
        get index of current server
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._current_id

    def worker_num(self):
        """
        retrun the current number of worker
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._worker_num()

    def server_num(self):
        """
        return the current number of server
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._server_num()

    def _barrier_worker(self):
        """
        barrier all workers in current distributed job
        """
        if not self._role_is_generated:
            self.generate_role()
        if self.is_worker():
            self._node_type_comm.barrier()

    def _barrier_all(self):
        """
        barrier all workers and servers in current distributed job
        """
        if not self._role_is_generated:
            self.generate_role()
        self._all_comm.barrier()

    def _barrier_server(self):
        """
        barrier all servers in current distributed job
        """
        if not self._role_is_generated:
            self.generate_role()
        if self.is_server():
            self._node_type_comm.barrier()

    def _worker_num(self):
        """
        return the current number of worker
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._trainers_num

    def _server_num(self):
        """
        return the current number of server
        """
        if not self._role_is_generated:
            self.generate_role()
        return len(self._server_endpoints)

    def _finalize(self):
        """Default do nothing."""
        pass

    def _all_reduce(self, input, output, mode="sum"):
        """
        all reduce between all workers

        Args:
            input(list|numpy.array): array of one dim
            output(list|numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        """
        if not self._role_is_generated:
            self.generate_role()
        input_list = [i for i in input]
        ans = self._node_type_comm.all_reduce(input_list, mode)
        for i in range(len(ans)):
            output[i] = ans[i]

    def _all_gather(self, obj):
        """
        gather between all workers and pservers
        """
        if not self._role_is_generated:
            self.generate_role()
        self._barrier_all()
        return self._all_comm.all_gather(obj)

    def _worker_gather(self, obj):
        """
        gather between all workers
        """
        if not self._role_is_generated:
            self.generate_role()
        if not self.is_worker():
            return None
        self._barrier_worker()
        return self._node_type_comm.all_gather(obj)

    def _get_rank(self):
        """
        get current rank in all workers and pservers
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._rank

    def _get_size(self):
        """
        get total num of all workers and pservers
        """
        if not self._role_is_generated:
            self.generate_role()
        return self._size

    def __get_default_iface(self):
        """
        get default physical interface
        """
        default1 = self.__get_default_iface_from_gateway()
        default2 = self.__get_default_iface_from_interfaces()
        return default2 if default1 == "lo" else default1

    def __get_default_iface_from_gateway(self):
        """
        get default physical interface
        """
        import netifaces
        gateways = netifaces.gateways()
        if gateways.get(netifaces.AF_INET) != None:
            gateway = gateways[netifaces.AF_INET]
            if len(gateway) > 0 and len(gateway[0]) > 1:
                return gateway[0][1]
        return "lo"

    def __get_default_iface_from_interfaces(self):
        """
        get default physical interface
        """
        import netifaces
        for intf_name in netifaces.interfaces():
            addresses = netifaces.ifaddresses(intf_name)
            if netifaces.AF_INET in addresses:
                ipv4_addresses = addresses[netifaces.AF_INET]
                for ipv4_address in ipv4_addresses:
                    if 'broadcast' in ipv4_address:
                        return intf_name
        return "lo"


class UserDefinedRoleMaker(RoleMakerBase):
    """
    UserDefinedRoleMaker is designed for worker and server assignment
    under manual. Typically, a worker and a server node will be appointed
    on each physical node, It can be assign by user.
    """

    def __init__(self,
                 current_id=0,
                 role=Role.WORKER,
                 worker_num=0,
                 server_endpoints=None):
        """
        UserDefinedRoleMaker is designed for worker and server assignment
        under manual. Typically, a worker and a server node will be appointed
        on each physical node, It can be assign by user.
        """
        super(UserDefinedRoleMaker, self).__init__()

        if not isinstance(server_endpoints, list):

@@ -495,11 +950,12 @@
class UserDefinedCollectiveRoleMaker(RoleMakerBase):
    """
    UserDefinedCollectiveRoleMaker is designed for worker assignment
    under manual for collective mode.
    """

    def __init__(self, current_id=0, worker_endpoints=None):
        """
        UserDefinedCollectiveRoleMaker is designed for worker assignment
        under manual for collective mode.
        """
        super(UserDefinedCollectiveRoleMaker, self).__init__()

        if not isinstance(worker_endpoints, list):
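
Putting the new GeneralRoleMaker together: a hedged end-to-end sketch for a single trainer process, mirroring the environment variables documented in its docstring and the usage exercised by test_fleet_rolemaker.py below. The endpoints and the rendezvous path are placeholders; netifaces must be installed, and the path (or an HDFS location via hdfs_name/hdfs_ugi) must be reachable by every process so they can rendezvous:

    # Hedged example; addresses and the rendezvous path are placeholders.
    import os
    from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

    os.environ["TRAINING_ROLE"] = "TRAINER"
    os.environ["PADDLE_TRAINER_ID"] = "0"
    os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
    os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"

    role = GeneralRoleMaker(path="./gloo_rendezvous")   # or hdfs_name/hdfs_ugi for an HDFS store
    role.generate_role()
    print(role.is_worker(), role.worker_index(), role.worker_num())

    fleet.init(role)   # PSLib.init now accepts any role maker instead of forcing MPI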
python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py

@@ -40,7 +40,9 @@ class PSLib(Fleet):
         self._client2client_max_retry = 3

     def init(self, role_maker=None):
-        super(PSLib, self).init(MPISymetricRoleMaker())
+        if role_maker is None:
+            role_maker = MPISymetricRoleMaker()
+        super(PSLib, self).init(role_maker)
         self._fleet_ptr = fluid.core.Fleet()

     def _set_client_communication_config(self, request_timeout_ms,

@@ -75,9 +77,10 @@ class PSLib(Fleet):
             # barrier_all for init_server, wait for server starts
             self._role_maker._barrier_all()
             self.all_ips_ = self._role_maker._all_gather(self._local_ip)
+            # worker_index * 2 is for compatible with older versions of pslib
             self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
                                         self._role_maker._get_size(),
-                                        self._role_maker._get_rank())
+                                        self._role_maker.worker_index() * 2)
             # barrier_all for init_worker
             self._role_maker._barrier_all()
             # prepare for client to client communication

@@ -160,9 +163,16 @@ class PSLib(Fleet):
         else:
             raise Exception(
                 "You should run DistributedOptimizer.minimize() first")
+        # server_index * 2 is for compatible with older versions of pslib
         self._fleet_ptr.init_server(self._dist_desc_str,
-                                    self._role_maker._get_rank())
-        self._local_ip = self._fleet_ptr.run_server()
+                                    self._role_maker.server_index() * 2)
+        if isinstance(self._role_maker, MPISymetricRoleMaker):
+            self._local_ip = self._fleet_ptr.run_server()
+        else:
+            local_endpoint = self._role_maker.get_local_endpoint()
+            local_endpoint = local_endpoint.split(":")
+            self._local_ip = self._fleet_ptr.run_server(
+                str(local_endpoint[0]), int(local_endpoint[1]))
         # barrier_all for init_server
         self._role_maker._barrier_all()

@@ -632,8 +642,8 @@ class DownpourOptimizer(DistributedOptimizer):
             parameter_list,
             no_grad_set,
             self._strategy)
-        opt_info["mpi_rank"] = fleet._role_maker._get_rank()
-        opt_info["mpi_size"] = fleet._role_maker._get_size()
+        opt_info["mpi_rank"] = fleet.worker_index()
+        opt_info["mpi_size"] = fleet.worker_num()
         fleet._set_opt_info(opt_info)

         programs = [loss.block.program for loss in losses]
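
On the server side, the changes above make pslib pick its bind address from the role maker: with MPISymetricRoleMaker the old run_server() behaviour is kept, otherwise the server is started on the ip:port returned by get_local_endpoint(). A hedged sketch of a pserver process under GeneralRoleMaker; the addresses are placeholders and the program/optimizer setup is omitted (the unit tests below show a full run):

    # Hedged example for a PSERVER process; every concrete value is a placeholder.
    import os
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker

    os.environ["TRAINING_ROLE"] = "PSERVER"
    os.environ["PADDLE_PSERVER_ID"] = "0"
    os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
    os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"

    fleet.init(GeneralRoleMaker(path="./gloo_rendezvous"))
    # ... build the program and call distributed_optimizer(...).minimize(...) first ...
    fleet.run_server()   # binds to the ip:port from role_maker.get_local_endpoint()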
python/paddle/fluid/incubate/fleet/utils/fleet_util.py

@@ -206,7 +206,7 @@ class FleetUtil(object):
             pos = pos.reshape(-1)
             global_pos = np.copy(pos) * 0
             # mpi allreduce
-            fleet._role_maker._node_type_comm.Allreduce(pos, global_pos)
+            fleet._role_maker._all_reduce(pos, global_pos)
             # reshape to its original shape
             global_pos = global_pos.reshape(old_pos_shape)

@@ -215,7 +215,7 @@ class FleetUtil(object):
             old_neg_shape = np.array(neg.shape)
             neg = neg.reshape(-1)
             global_neg = np.copy(neg) * 0
-            fleet._role_maker._node_type_comm.Allreduce(neg, global_neg)
+            fleet._role_maker._all_reduce(neg, global_neg)
             global_neg = global_neg.reshape(old_neg_shape)

             # calculate auc

@@ -1350,7 +1350,7 @@ class FleetUtil(object):
         pos = pos.reshape(-1)
         global_pos = np.copy(pos) * 0
         # mpi allreduce
-        fleet._role_maker._node_type_comm.Allreduce(pos, global_pos)
+        fleet._role_maker._all_reduce(pos, global_pos)
         # reshape to its original shape
         global_pos = global_pos.reshape(old_pos_shape)
         # auc neg bucket

@@ -1358,7 +1358,7 @@ class FleetUtil(object):
         old_neg_shape = np.array(neg.shape)
         neg = neg.reshape(-1)
         global_neg = np.copy(neg) * 0
-        fleet._role_maker._node_type_comm.Allreduce(neg, global_neg)
+        fleet._role_maker._all_reduce(neg, global_neg)
         global_neg = global_neg.reshape(old_neg_shape)

         num_bucket = len(global_pos[0])

@@ -1368,7 +1368,7 @@ class FleetUtil(object):
         old_metric_shape = np.array(metric.shape)
         metric = metric.reshape(-1)
         global_metric = np.copy(metric) * 0
-        fleet._role_maker._node_type_comm.Allreduce(metric, global_metric)
+        fleet._role_maker._all_reduce(metric, global_metric)
         global_metric = global_metric.reshape(old_metric_shape)

         return global_metric[0]
python/paddle/fluid/tests/unittests/test_dataset.py

@@ -733,7 +733,7 @@ class TestDataset2(unittest.TestCase):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         try:
-            fleet.init(exe)
+            fleet.init()
         except ImportError as e:
             print("warning: no mpi4py")
         adam = fluid.optimizer.Adam(learning_rate=0.000005)

@@ -795,7 +795,7 @@ class TestDataset2(unittest.TestCase):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         try:
-            fleet.init(exe)
+            fleet.init()
         except ImportError as e:
             print("warning: no mpi4py")
         adam = fluid.optimizer.Adam(learning_rate=0.000005)

@@ -824,6 +824,10 @@ class TestDataset2(unittest.TestCase):
         dataset.set_pipe_command("cat")
         dataset.set_use_var(slots_vars)
         dataset.load_into_memory()
+        try:
+            dataset.global_shuffle(fleet)
+        except:
+            print("warning: catch expected error")
         fleet._opt_info = None
         fleet._fleet_ptr = None
python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py

@@ -11,36 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""

from __future__ import print_function
import os
import unittest
import paddle.fluid.incubate.fleet.base.role_maker as role_maker


class TestCloudRoleMaker(unittest.TestCase):
    """
    Test cases for PaddleCloudRoleMaker.
    """

    def setUp(self):
        """Set up, set envs."""
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ[
            "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"

    def test_tr_rolemaker(self):
        """Test tr rolenamer."""
        os.environ["TRAINING_ROLE"] = "TRAINER"
        os.environ["PADDLE_TRAINER_ID"] = "0"

        ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
        ro.generate_role()
        self.assertTrue(ro.is_worker())
        self.assertFalse(ro.is_server())
        self.assertEqual(ro.worker_num(), 2)

    def test_ps_rolemaker(self):
        """Test ps rolemaker."""
        os.environ["TRAINING_ROLE"] = "PSERVER"
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_PORT"] = "36001"

        ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
        ro.generate_role()
        self.assertFalse(ro.is_worker())

@@ -48,10 +53,75 @@ class TestCloudRoleMaker(unittest.TestCase):  (test_pslib_1 is added)
        self.assertEqual(ro.worker_num(), 2)

    def test_traing_role(self):
        """Test training role."""
        os.environ["TRAINING_ROLE"] = "TEST"

        ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
        self.assertRaises(ValueError, ro.generate_role)

    def test_pslib_1(self):
        """Test cases for pslib."""
        import paddle.fluid as fluid
        from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
        from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
        from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
        try:
            import netifaces
        except:
            print("warning: no netifaces, skip test_pslib_1")
            return
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_PORT"] = "36001"
        os.environ["TRAINING_ROLE"] = "TRAINER"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
        os.environ["PADDLE_TRAINER_ID"] = "0"
        role_maker = GeneralRoleMaker()
        role_maker.generate_role()
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        fleet.init(role_maker)
        train_program = fluid.Program()
        startup_program = fluid.Program()
        scope = fluid.Scope()
        with fluid.program_guard(train_program, startup_program):
            show = fluid.layers.data(name="show", shape=[-1, 1], \
                dtype="float32", lod_level=1, append_batch_size=False)
            fc = fluid.layers.fc(input=show, size=1, act=None)
            label = fluid.layers.data(name="click", shape=[-1, 1], \
                dtype="int64", lod_level=1, append_batch_size=False)
            label_cast = fluid.layers.cast(label, dtype='float32')
            cost = fluid.layers.log_loss(fc, label_cast)
        try:
            adam = fluid.optimizer.Adam(learning_rate=0.000005)
            adam = fleet.distributed_optimizer(adam)
            adam.minimize([cost], [scope])
            fleet.run_server()
        except:
            print("do not support pslib test, skip")
            return
        from paddle.fluid.incubate.fleet.base.role_maker import \
            MPISymetricRoleMaker
        try:
            role = MPISymetricRoleMaker()
            role._all_reduce([1], [2])
        except:
            print("catch expected error of not inited")
        try:
            role = MPISymetricRoleMaker()
            role._all_reduce([1], [2], "min")
        except:
            print("catch expected error of not inited")
        try:
            role = MPISymetricRoleMaker()
            role._all_reduce([1], [2], "max")
        except:
            print("catch expected error of not inited")
        try:
            role = MPISymetricRoleMaker()
            role._all_reduce([1], [2], "unknown")
        except:
            print("catch expected error of unknown type")


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py  (new file, mode 100644)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cases for role makers."""

from __future__ import print_function
import os
import unittest
import paddle.fluid.incubate.fleet.base.role_maker as role_maker


class TestCloudRoleMaker2(unittest.TestCase):
    """
    Test cases for paddle cloud role makers.
    """

    def setUp(self):
        """Set up, set envs."""
        pass

    def test_pslib_2(self):
        """Test cases for pslib."""
        import paddle.fluid as fluid
        from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
        from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
        from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
        from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase
        try:
            import netifaces
        except:
            print("warning: no netifaces, skip test_pslib_2")
            return
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_PORT"] = "36001"
        os.environ["TRAINING_ROLE"] = "TRAINER"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
        os.environ["PADDLE_TRAINER_ID"] = "0"
        os.environ["PADDLE_TRAINERS_NUM"] = "1"
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        try:
            fleet.init(None)
        except:
            print("no mpi4py, skip test_pslib_2")
            return
        train_program = fluid.Program()
        startup_program = fluid.Program()
        scope = fluid.Scope()
        with fluid.program_guard(train_program, startup_program):
            show = fluid.layers.data(name="show", shape=[-1, 1], \
                dtype="float32", lod_level=1, append_batch_size=False)
            fc = fluid.layers.fc(input=show, size=1, act=None)
            label = fluid.layers.data(name="click", shape=[-1, 1], \
                dtype="int64", lod_level=1, append_batch_size=False)
            label_cast = fluid.layers.cast(label, dtype='float32')
            cost = fluid.layers.log_loss(fc, label_cast)
        try:
            adam = fluid.optimizer.Adam(learning_rate=0.000005)
            adam = fleet.distributed_optimizer(adam)
            adam.minimize([cost], [scope])
            fleet.run_server()
        except:
            print("do not support pslib test, skip")
            return
        os.environ["TRAINING_ROLE"] = "wrong"
        try:
            role1 = GeneralRoleMaker(path="./test_gloo_1")
            role1.generate_role()
        except:
            print("catch expected error of wrong TRAINING_ROLE")
        os.environ["TRAINING_ROLE"] = "PSERVER"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
        role2 = GeneralRoleMaker(path="./test_gloo_2")
        role2._finalize()
        role2._all_gather(1)
        role2._all_gather(1)
        role2._barrier_server()
        role2.all_gather(1)
        role3 = GeneralRoleMaker(path="./test_gloo_3")
        role3._worker_gather(1)
        role3._worker_gather(1)
        os.environ["TRAINING_ROLE"] = "TRAINER"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
        role4 = GeneralRoleMaker(path="./test_gloo_4")
        role4._worker_gather(1)
        role4._get_rank()
        role4._get_size()
        role4._all_comm.init(0, 0, "", "", "", "", "")
        role5 = GeneralRoleMaker(path="./test_gloo_5")
        role5.get_local_endpoint()
        role5.get_local_endpoint()
        role6 = GeneralRoleMaker(path="./test_gloo_6")
        role6.get_trainer_endpoints()
        role6.get_trainer_endpoints()
        role7 = GeneralRoleMaker(path="./test_gloo_7")
        role7.get_pserver_endpoints()
        role7.get_pserver_endpoints()
        role8 = GeneralRoleMaker(path="./test_gloo_8")
        role8.is_worker()
        role8.is_worker()
        role9 = GeneralRoleMaker(path="./test_gloo_9")
        role9.is_server()
        role9.is_server()
        role10 = GeneralRoleMaker(path="./test_gloo_10")
        role10.is_first_worker()
        role10.is_first_worker()
        role11 = GeneralRoleMaker(path="./test_gloo_11")
        role11.worker_index()
        role11.worker_index()
        role12 = GeneralRoleMaker(path="./test_gloo_12")
        role12.server_index()
        role12.server_index()
        role13 = GeneralRoleMaker(path="./test_gloo_13")
        role13.worker_num()
        role13.worker_num()
        role14 = GeneralRoleMaker(path="./test_gloo_14")
        role14.server_num()
        role14.server_num()
        role15 = GeneralRoleMaker(path="./test_gloo_15")
        role15._barrier_worker()
        role15._barrier_worker()
        role16 = GeneralRoleMaker(path="./test_gloo_16")
        role16._barrier_all()
        role16._barrier_all()
        role17 = GeneralRoleMaker(path="./test_gloo_17")
        role17._barrier_server()
        role17._barrier_server()
        role18 = GeneralRoleMaker(path="./test_gloo_18")
        role18._worker_num()
        role18._worker_num()
        role19 = GeneralRoleMaker(path="./test_gloo_19")
        role19._server_num()
        role19._server_num()
        role20 = GeneralRoleMaker(path="./test_gloo_20")
        a = [1]
        b = [0]
        role20._all_reduce(a, b)
        role21 = GeneralRoleMaker(path="./test_gloo_21")
        role21.all_reduce_worker([], [])
        role21.all_reduce_worker([], [])
        role21.barrier_worker()
        role21.barrier_all()
        role22 = GeneralRoleMaker(path="./test_gloo_22")
        role22._get_rank()
        role22._get_rank()
        os.environ["PADDLE_PSERVER_ID"] = "0"
        role23 = GeneralRoleMaker(path="./test_gloo_23")
        role23._get_size()
        role23._get_size()
        with open("test_fleet_gloo_role_maker_1.txt", "w") as f:
            data = "1 1 1 1\n"
            f.write(data)
        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_filelist(["test_fleet_gloo_role_maker_1.txt"])
        dataset.set_use_var([show, label])
        dataset.load_into_memory()
        dataset.get_memory_data_size(fleet)
        dataset.get_shuffle_data_size(fleet)
        os.remove("./test_fleet_gloo_role_maker_1.txt")

        class TmpClass():
            """
            dummy tmp class
            """

            def __init__(self):
                pass

            def all_reduce_worker(self, input, output):
                """
                dummy all reduce worker

                Args:
                    input(None): fake input
                    output(None): fale output
                """
                pass

            def barrier_worker(self):
                """
                dummy barrier worker
                """
                pass

        from paddle.fluid.incubate.fleet.base.fleet_base import Fleet

        class TmpFleet(Fleet):
            """
            dummy tmp fleet
            """

            def __init__(self):
                super(Fleet, self).__init__()
                self._role_maker = None

            def init_worker(self):
                """
                dummy init worker
                """
                pass

            def init_server(self, model_dir=None):
                """
                dummy init server

                Args:
                    model_dir(None): fake model_dir
                """
                pass

            def run_server(self):
                """
                dummy run server
                """
                pass

            def stop_worker(self):
                """
                dummy stop worker
                """
                pass

            def distributed_optimizer(self, optimizer, strategy=None):
                """
                dummy distributed optimizer

                Args:
                    optimizer(None): fake optimizer
                    strategy(None): fake strategy
                """
                pass

            def save_inference_model(self):
                """
                dummy save inference model
                """
                pass

            def save_persistables(self):
                """
                dummy save persistables
                """
                pass

        os.environ["TRAINING_ROLE"] = "TRAINER"
        tmp = TmpFleet()
        tmp._role_maker = TmpClass()
        tmp.all_reduce_worker([], [])
        tmp.barrier_worker()
        from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
        tmp = RoleMakerBase()
        tmp.all_gather(1)
        tmp.all_reduce_worker([], [])
        tmp.barrier_worker()
        tmp.barrier_all()
        from paddle.fluid.incubate.fleet.base.role_maker import \
            MPISymetricRoleMaker
        tmp1 = MPISymetricRoleMaker()
        tmp1.all_gather(1)
        tmp1.all_gather(1)
        tmp2 = MPISymetricRoleMaker()
        tmp2.all_reduce_worker([], [])
        tmp3 = MPISymetricRoleMaker()
        tmp3.barrier_worker()
        tmp3.barrier_worker()
        tmp4 = MPISymetricRoleMaker()
        tmp4.barrier_all()
        tmp4.barrier_all()


if __name__ == "__main__":
    unittest.main()