Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b95cd3b7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b95cd3b7
编写于
2月 22, 2022
作者:
L
lilong12
提交者:
GitHub
2月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the implementation of TCP Store (#39384)
* add tcp_socket and tcp_store
上级
da43e065
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
877 addition
and
1 deletion
+877
-1
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+2
-0
paddle/fluid/distributed/store/CMakeLists.txt
paddle/fluid/distributed/store/CMakeLists.txt
+1
-0
paddle/fluid/distributed/store/store.h
paddle/fluid/distributed/store/store.h
+43
-0
paddle/fluid/distributed/store/tcp_store.cc
paddle/fluid/distributed/store/tcp_store.cc
+272
-0
paddle/fluid/distributed/store/tcp_store.h
paddle/fluid/distributed/store/tcp_store.h
+114
-0
paddle/fluid/distributed/store/tcp_utils.cc
paddle/fluid/distributed/store/tcp_utils.cc
+201
-0
paddle/fluid/distributed/store/tcp_utils.h
paddle/fluid/distributed/store/tcp_utils.h
+133
-0
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+2
-1
paddle/fluid/pybind/communication.cc
paddle/fluid/pybind/communication.cc
+42
-0
paddle/fluid/pybind/communication.h
paddle/fluid/pybind/communication.h
+31
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-0
python/paddle/fluid/tests/unittests/test_tcp_store.py
python/paddle/fluid/tests/unittests/test_tcp_store.py
+34
-0
未找到文件。
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
b95cd3b7
add_subdirectory
(
store
)
if
(
NOT WITH_PSCORE
)
if
(
NOT WITH_PSCORE
)
add_subdirectory
(
fleet_executor
)
add_subdirectory
(
fleet_executor
)
return
()
return
()
...
...
paddle/fluid/distributed/store/CMakeLists.txt
0 → 100644
浏览文件 @
b95cd3b7
cc_library
(
tcp_store SRCS tcp_store.cc tcp_utils.cc DEPS enforce glog
)
paddle/fluid/distributed/store/store.h
0 → 100644
浏览文件 @
b95cd3b7
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/store/tcp_utils.h"
namespace
paddle
{
namespace
distributed
{
class
Store
{
public:
Store
()
=
delete
;
explicit
Store
(
const
std
::
chrono
::
seconds
&
timeout
)
:
_timeout
(
timeout
)
{}
virtual
~
Store
()
=
default
;
virtual
int64_t
add
(
const
std
::
string
&
key
,
int64_t
value
)
=
0
;
virtual
std
::
vector
<
uint8_t
>
get
(
const
std
::
string
&
key
)
=
0
;
virtual
void
wait
(
const
std
::
string
&
key
)
=
0
;
virtual
const
std
::
chrono
::
seconds
&
timeout
()
const
{
return
_timeout
;
}
private:
std
::
chrono
::
seconds
_timeout
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/store/tcp_store.cc
0 → 100644
浏览文件 @
b95cd3b7
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono>
#include <iostream>
#include <thread>
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
namespace
detail
{
constexpr
int
INFTIME
=
-
1
;
std
::
unique_ptr
<
MasterDaemon
>
MasterDaemon
::
start
(
SocketType
socket
)
{
return
std
::
make_unique
<
MasterDaemon
>
(
socket
);
}
MasterDaemon
::
MasterDaemon
(
SocketType
socket
)
:
_listen_socket
(
socket
)
{
_background_thread
=
std
::
thread
{
&
MasterDaemon
::
run
,
this
};
}
MasterDaemon
::~
MasterDaemon
()
{
_background_thread
.
join
();
tcputils
::
close_socket
(
_listen_socket
);
for
(
SocketType
socket
:
_sockets
)
{
tcputils
::
close_socket
(
socket
);
}
}
void
MasterDaemon
::
_do_add
(
SocketType
socket
)
{
int64_t
new_value
{};
std
::
string
key
=
tcputils
::
receive_string
(
socket
);
new_value
=
tcputils
::
receive_value
<
int64_t
>
(
socket
);
std
::
vector
<
uint8_t
>
old_value
;
auto
it
=
_store
.
find
(
key
);
if
(
it
!=
_store
.
end
())
{
old_value
=
it
->
second
;
char
*
buffer
=
reinterpret_cast
<
char
*>
(
it
->
second
.
data
());
size_t
len
=
old_value
.
size
();
new_value
+=
std
::
stoll
(
std
::
string
(
buffer
,
len
));
}
std
::
string
new_value_str
=
std
::
to_string
(
new_value
);
_store
[
key
]
=
std
::
vector
<
uint8_t
>
(
new_value_str
.
begin
(),
new_value_str
.
end
());
VLOG
(
3
)
<<
"TCPStore: new value ("
<<
new_value
<<
") for key ("
<<
key
<<
")."
;
tcputils
::
send_value
<
int64_t
>
(
socket
,
new_value
);
}
void
MasterDaemon
::
_do_get
(
SocketType
socket
)
{
std
::
string
key
=
tcputils
::
receive_string
(
socket
);
auto
iter
=
_store
.
find
(
key
);
PADDLE_ENFORCE_NE
(
iter
,
_store
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Key %s not found in TCPStore."
,
key
));
std
::
vector
<
uint8_t
>
value
=
iter
->
second
;
VLOG
(
3
)
<<
"TCPStore: value ("
<<
std
::
stoll
(
std
::
string
(
reinterpret_cast
<
char
*>
(
value
.
data
()),
value
.
size
()))
<<
") for key ("
<<
key
<<
")."
;
tcputils
::
send_vector
<
uint8_t
>
(
socket
,
value
);
}
void
MasterDaemon
::
_do_stop
(
SocketType
socket
)
{
ReplyType
value
=
ReplyType
::
STOP_WAIT
;
_stop
=
true
;
tcputils
::
send_value
<
ReplyType
>
(
socket
,
value
);
}
void
MasterDaemon
::
_do_wait
(
SocketType
socket
)
{
std
::
string
key
=
tcputils
::
receive_string
(
socket
);
auto
iter
=
_store
.
find
(
key
);
auto
reply
=
ReplyType
::
STOP_WAIT
;
if
(
iter
==
_store
.
end
())
{
reply
=
ReplyType
::
WAITING
;
}
VLOG
(
3
)
<<
"TCPStore: wait reply ("
<<
static_cast
<
int
>
(
reply
)
<<
") for key ("
<<
key
<<
")."
;
tcputils
::
send_value
<
ReplyType
>
(
socket
,
reply
);
}
void
MasterDaemon
::
run
()
{
std
::
vector
<
struct
pollfd
>
fds
;
#ifdef _WIN32
fds
.
push_back
({
_listen_socket
,
POLLIN
});
#else
fds
.
push_back
({.
fd
=
_listen_socket
,
.
events
=
POLLIN
,
.
revents
=
0
});
#endif
while
(
!
_stop
)
{
for
(
size_t
i
=
0
;
i
<
fds
.
size
();
i
++
)
{
fds
[
i
].
revents
=
0
;
}
#ifdef _WIN32
::
WSAPoll
(
fds
.
data
(),
fds
.
size
(),
INFTIME
);
#else
::
poll
(
fds
.
data
(),
fds
.
size
(),
INFTIME
);
#endif
if
(
fds
[
0
].
revents
!=
0
)
{
auto
socket
=
tcputils
::
tcp_accept
(
_listen_socket
);
_sockets
.
emplace_back
(
socket
);
#ifdef _WIN32
fds
.
push_back
({
socket
,
POLLIN
});
#else
fds
.
push_back
({.
fd
=
socket
,
.
events
=
POLLIN
,
.
revents
=
0
});
#endif
}
for
(
size_t
i
=
1
;
i
<
fds
.
size
();
i
++
)
{
if
(
fds
[
i
].
revents
==
0
)
{
continue
;
}
Command
command
=
tcputils
::
receive_value
<
Command
>
(
fds
[
i
].
fd
);
VLOG
(
3
)
<<
"TCPStore: recv command: "
<<
static_cast
<
int
>
(
command
)
<<
"."
;
switch
(
command
)
{
case
Command
::
ADD
:
_do_add
(
fds
[
i
].
fd
);
break
;
case
Command
::
GET
:
_do_get
(
fds
[
i
].
fd
);
break
;
case
Command
::
WAIT
:
_do_wait
(
fds
[
i
].
fd
);
break
;
case
Command
::
STOP
:
_do_stop
(
fds
[
i
].
fd
);
break
;
}
}
}
}
std
::
unique_ptr
<
TCPServer
>
TCPServer
::
create
(
uint16_t
port
)
{
int
socket
=
tcputils
::
tcp_listen
(
""
,
std
::
to_string
(
port
),
AF_INET
);
auto
server
=
std
::
make_unique
<
TCPServer
>
();
server
->
_master_daemon
=
MasterDaemon
::
start
(
socket
);
return
server
;
}
std
::
unique_ptr
<
TCPClient
>
TCPClient
::
connect
(
const
std
::
string
host
,
uint16_t
port
)
{
int
socket
=
tcputils
::
tcp_connect
(
host
,
std
::
to_string
(
port
),
AF_INET
);
return
std
::
make_unique
<
TCPClient
>
(
socket
);
}
void
TCPClient
::
send_command_for_key
(
Command
type
,
const
std
::
string
&
key
)
{
tcputils
::
send_value
<
Command
>
(
_socket
,
type
);
if
(
key
.
empty
())
{
return
;
}
tcputils
::
send_string
(
_socket
,
key
);
}
template
<
typename
T
>
void
TCPClient
::
send_value
(
const
T
&
value
)
{
tcputils
::
send_bytes
<
T
>
(
_socket
,
&
value
,
1
);
}
template
<
typename
T
>
T
TCPClient
::
receive_value
()
{
T
res
;
tcputils
::
receive_bytes
<
T
>
(
_socket
,
&
res
,
1
);
return
res
;
}
template
<
typename
T
>
void
TCPClient
::
send_vector
(
const
std
::
vector
<
T
>&
value
)
{
tcputils
::
send_vector
<
T
>
(
_socket
,
value
);
}
template
<
typename
T
>
std
::
vector
<
T
>
TCPClient
::
receive_vector
()
{
return
tcputils
::
receive_vector
<
T
>
(
_socket
);
}
}
// namespace detail
TCPStore
::
TCPStore
(
std
::
string
host
,
uint16_t
port
,
bool
is_master
,
size_t
num_workers
,
std
::
chrono
::
seconds
timeout
)
:
Store
(
timeout
),
_is_master
(
is_master
),
_num_workers
(
num_workers
)
{
if
(
_is_master
)
{
_server
=
detail
::
TCPServer
::
create
(
port
);
}
_client
=
detail
::
TCPClient
::
connect
(
host
,
port
);
waitWorkers
();
}
void
TCPStore
::
waitWorkers
()
{
if
(
_num_workers
==
0
)
{
return
;
}
add
(
_init_key
,
1
);
if
(
_server
)
{
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
do
{
auto
value
=
get
(
_init_key
);
int
completed
=
std
::
stoi
(
std
::
string
(
value
.
begin
(),
value
.
end
()));
VLOG
(
3
)
<<
completed
<<
" worker ready, total "
<<
_num_workers
;
if
(
completed
>=
_num_workers
)
{
break
;
}
const
auto
elapsed
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
seconds
>
(
std
::
chrono
::
steady_clock
::
now
()
-
begin
);
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
if
(
_timeout
!=
tcputils
::
kNoTimeout
&&
elapsed
>
_timeout
)
{
PADDLE_ENFORCE_EQ
(
completed
,
_num_workers
,
platform
::
errors
::
InvalidArgument
(
"TCPStore timeouted and not all workers got ready."
));
}
}
while
(
true
);
}
VLOG
(
3
)
<<
"TCPStore initialized."
;
}
int64_t
TCPStore
::
add
(
const
std
::
string
&
key
,
int64_t
value
)
{
_client
->
send_command_for_key
(
Command
::
ADD
,
_key_prefix
+
key
);
_client
->
send_value
<
std
::
int64_t
>
(
value
);
return
_client
->
receive_value
<
std
::
int64_t
>
();
}
std
::
vector
<
uint8_t
>
TCPStore
::
get
(
const
std
::
string
&
key
)
{
wait
(
key
);
_client
->
send_command_for_key
(
Command
::
GET
,
_key_prefix
+
key
);
VLOG
(
3
)
<<
"TCPStore get."
;
return
_client
->
receive_vector
<
uint8_t
>
();
}
void
TCPStore
::
wait
(
const
std
::
string
&
key
)
{
ReplyType
reply
;
do
{
_client
->
send_command_for_key
(
Command
::
WAIT
,
_key_prefix
+
key
);
reply
=
_client
->
receive_value
<
ReplyType
>
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
500
));
}
while
(
reply
!=
ReplyType
::
STOP_WAIT
);
}
TCPStore
::~
TCPStore
()
{
_client
->
send_command_for_key
(
Command
::
STOP
,
""
);
ReplyType
ret
=
_client
->
receive_value
<
ReplyType
>
();
PADDLE_ENFORCE_EQ
(
ret
,
ReplyType
::
STOP_WAIT
,
platform
::
errors
::
InvalidArgument
(
"The reply for TCPStore destructure must be 0."
));
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/store/tcp_store.h
0 → 100644
浏览文件 @
b95cd3b7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
namespace
paddle
{
namespace
distributed
{
enum
class
ReplyType
{
WAITING
,
STOP_WAIT
};
enum
class
Command
{
ADD
,
GET
,
WAIT
,
STOP
};
namespace
detail
{
class
MasterDaemon
{
public:
static
std
::
unique_ptr
<
MasterDaemon
>
start
(
SocketType
listen_socket
);
MasterDaemon
()
=
delete
;
explicit
MasterDaemon
(
SocketType
listen_socket
);
~
MasterDaemon
();
private:
void
run
();
void
_do_add
(
SocketType
socket
);
void
_do_wait
(
SocketType
socket
);
void
_do_get
(
SocketType
socket
);
void
_do_stop
(
SocketType
socket
);
SocketType
_listen_socket
;
std
::
vector
<
SocketType
>
_sockets
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
uint8_t
>>
_store
;
std
::
thread
_background_thread
{};
bool
_stop
=
false
;
};
class
TCPServer
{
public:
TCPServer
()
=
default
;
static
std
::
unique_ptr
<
TCPServer
>
create
(
std
::
uint16_t
port
);
private:
std
::
unique_ptr
<
MasterDaemon
>
_master_daemon
;
};
class
TCPClient
{
public:
explicit
TCPClient
(
SocketType
socket
)
:
_socket
{
socket
}
{}
static
std
::
unique_ptr
<
TCPClient
>
connect
(
const
std
::
string
host
,
uint16_t
port
);
~
TCPClient
()
{
tcputils
::
close_socket
(
_socket
);
}
void
send_command_for_key
(
Command
type
,
const
std
::
string
&
key
);
template
<
typename
T
>
void
send_value
(
const
T
&
value
);
template
<
typename
T
>
void
send_vector
(
const
std
::
vector
<
T
>&
value
);
template
<
typename
T
>
std
::
vector
<
T
>
receive_vector
();
template
<
typename
T
>
T
receive_value
();
private:
SocketType
_socket
;
};
}
// namespace detail
class
TCPStore
:
public
Store
{
public:
static
constexpr
std
::
uint16_t
kDefaultPort
=
6170
;
explicit
TCPStore
(
std
::
string
host
,
uint16_t
port
=
kDefaultPort
,
bool
is_master
=
false
,
size_t
num_workers
=
1
,
std
::
chrono
::
seconds
timeout
=
tcputils
::
kDefaultTimeout
);
~
TCPStore
();
int64_t
add
(
const
std
::
string
&
key
,
int64_t
value
)
override
;
std
::
vector
<
uint8_t
>
get
(
const
std
::
string
&
key
)
override
;
void
wait
(
const
std
::
string
&
key
)
override
;
private:
void
waitWorkers
();
std
::
unique_ptr
<
detail
::
TCPServer
>
_server
;
std
::
unique_ptr
<
detail
::
TCPClient
>
_client
;
const
std
::
string
_init_key
=
"init/"
;
const
std
::
string
_key_prefix
=
"/"
;
std
::
chrono
::
seconds
_timeout
;
bool
_is_master
;
int
_num_workers
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/store/tcp_utils.cc
0 → 100644
浏览文件 @
b95cd3b7
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include <cerrno>
#include <cstring>
#include <thread>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
namespace
tcputils
{
std
::
error_code
socket_error
()
{
#ifdef _WIN32
return
std
::
error_code
{
::
WSAGetLastError
(),
std
::
generic_category
()};
#else
return
std
::
error_code
{
errno
,
std
::
generic_category
()};
#endif
}
void
close_socket
(
SocketType
socket
)
{
#ifdef _WIN32
::
closesocket
(
socket
);
#else
::
close
(
socket
);
#endif
}
::
addrinfo
*
get_addr_info
(
const
std
::
string
host
,
const
std
::
string
port
,
int
ai_flags
,
int
family
)
{
::
addrinfo
hints
{},
*
res
;
hints
.
ai_flags
=
ai_flags
;
hints
.
ai_family
=
family
;
hints
.
ai_socktype
=
SOCK_STREAM
;
const
char
*
node
=
host
.
empty
()
?
nullptr
:
host
.
c_str
();
int
n
;
n
=
::
getaddrinfo
(
node
,
port
.
c_str
(),
&
hints
,
&
res
);
const
char
*
gai_err
=
::
gai_strerror
(
n
);
const
char
*
proto
=
(
family
==
AF_INET
?
"IPv4"
:
family
==
AF_INET6
?
"IPv6"
:
""
);
PADDLE_ENFORCE_EQ
(
n
,
0
,
platform
::
errors
::
InvalidArgument
(
"%s network %s:%s cannot be obtained. Details: %s."
,
proto
,
host
,
port
,
gai_err
));
return
res
;
}
void
free_addr_info
(
::
addrinfo
*
hint
)
{
PADDLE_ENFORCE_NOT_NULL
(
hint
,
platform
::
errors
::
InvalidArgument
(
"The parameter for free_addr_info cannot be null."
));
::
freeaddrinfo
(
hint
);
}
SocketType
tcp_connect
(
const
std
::
string
host
,
const
std
::
string
port
,
int
family
,
std
::
chrono
::
seconds
timeout
)
{
int
ai_flags
=
AI_NUMERICSERV
|
AI_V4MAPPED
|
AI_ALL
;
::
addrinfo
*
res
=
get_addr_info
(
host
,
port
,
ai_flags
,
family
);
SocketType
sockfd
=
-
1
;
bool
retry
=
true
;
auto
deadline
=
std
::
chrono
::
steady_clock
::
now
()
+
timeout
;
do
{
for
(
::
addrinfo
*
cur
=
res
;
cur
!=
nullptr
;
cur
=
cur
->
ai_next
)
{
sockfd
=
::
socket
(
cur
->
ai_family
,
cur
->
ai_socktype
,
cur
->
ai_protocol
);
PADDLE_ENFORCE_GT
(
sockfd
,
0
,
platform
::
errors
::
InvalidArgument
(
"Create socket to connect %s:%s failed. "
"Details: %s. "
,
host
,
port
,
socket_error
().
message
()));
if
(
::
connect
(
sockfd
,
cur
->
ai_addr
,
cur
->
ai_addrlen
)
==
0
)
{
retry
=
false
;
break
;
}
VLOG
(
0
)
<<
"Retry to connect to "
<<
host
<<
":"
<<
port
<<
" while the server is not yet listening."
;
close_socket
(
sockfd
);
sockfd
=
-
1
;
std
::
this_thread
::
sleep_for
(
kDelay
);
if
(
timeout
!=
kNoTimeout
&&
std
::
chrono
::
steady_clock
::
now
()
>=
deadline
)
{
retry
=
false
;
break
;
}
}
if
(
timeout
!=
kNoTimeout
&&
std
::
chrono
::
steady_clock
::
now
()
>=
deadline
)
{
retry
=
false
;
}
}
while
(
retry
);
free_addr_info
(
res
);
PADDLE_ENFORCE_GT
(
sockfd
,
0
,
platform
::
errors
::
InvalidArgument
(
"Network %s:%s cannot be connected."
,
host
,
port
));
VLOG
(
0
)
<<
"Successfully connected to "
<<
host
<<
":"
<<
port
;
return
sockfd
;
}
SocketType
tcp_listen
(
const
std
::
string
host
,
const
std
::
string
port
,
int
family
)
{
int
ai_flags
=
AI_PASSIVE
|
AI_NUMERICSERV
;
::
addrinfo
*
res
=
get_addr_info
(
host
,
port
,
ai_flags
,
family
);
::
addrinfo
*
cur
=
res
;
SocketType
sockfd
{};
std
::
string
node
=
host
.
empty
()
?
"IP_ANY"
:
host
;
while
(
cur
)
{
sockfd
=
::
socket
(
cur
->
ai_family
,
cur
->
ai_socktype
,
cur
->
ai_protocol
);
if
(
sockfd
<
0
)
{
VLOG
(
0
)
<<
"Cannot create socket on "
<<
node
<<
":"
<<
port
<<
". Details: "
<<
socket_error
().
message
();
cur
=
cur
->
ai_next
;
continue
;
}
int
on
=
1
;
#ifdef _WIN32
int
ret
=
::
setsockopt
(
sockfd
,
SOL_SOCKET
,
SO_REUSEADDR
,
reinterpret_cast
<
char
*>
(
&
on
),
sizeof
(
on
));
#else
int
ret
=
::
setsockopt
(
sockfd
,
SOL_SOCKET
,
SO_REUSEADDR
,
&
on
,
sizeof
(
on
));
#endif
if
(
ret
<
0
)
{
VLOG
(
0
)
<<
"Set the address reuse option failed on the server."
;
}
if
(
::
bind
(
sockfd
,
res
->
ai_addr
,
res
->
ai_addrlen
)
==
0
)
{
break
;
}
close_socket
(
sockfd
);
sockfd
=
-
1
;
cur
=
cur
->
ai_next
;
}
PADDLE_ENFORCE_GT
(
sockfd
,
0
,
platform
::
errors
::
InvalidArgument
(
"Bind network on %s:%s failedd."
,
node
,
port
));
::
listen
(
sockfd
,
LISTENQ
);
VLOG
(
0
)
<<
"The server starts to listen on "
<<
node
<<
":"
<<
port
;
return
sockfd
;
}
SocketType
tcp_accept
(
SocketType
socket
)
{
::
sockaddr_storage
addr_s
{};
::
socklen_t
addr_len
=
sizeof
(
addr_s
);
SocketType
new_socket
=
::
accept
(
socket
,
reinterpret_cast
<::
sockaddr
*>
(
&
addr_s
),
&
addr_len
);
PADDLE_ENFORCE_GT
(
new_socket
,
0
,
platform
::
errors
::
InvalidArgument
(
"The server failed to accept a new connection. Details: %s."
,
socket_error
().
message
()));
#ifndef _WIN32
::
fcntl
(
new_socket
,
F_SETFD
,
FD_CLOEXEC
);
#endif
auto
value
=
1
;
#ifdef _WIN32
::
setsockopt
(
new_socket
,
IPPROTO_TCP
,
TCP_NODELAY
,
reinterpret_cast
<
const
char
*>
(
&
value
),
sizeof
(
value
));
#else
::
setsockopt
(
new_socket
,
IPPROTO_TCP
,
TCP_NODELAY
,
&
value
,
sizeof
(
value
));
#endif
return
new_socket
;
}
void
send_string
(
SocketType
socket
,
const
std
::
string
&
s
)
{
std
::
string
::
size_type
size
=
s
.
size
();
send_bytes
<
std
::
string
::
size_type
>
(
socket
,
&
size
,
1
);
send_bytes
<
const
char
>
(
socket
,
s
.
data
(),
size
);
}
std
::
string
receive_string
(
SocketType
socket
)
{
std
::
string
::
size_type
size
;
receive_bytes
<
std
::
string
::
size_type
>
(
socket
,
&
size
,
1
);
std
::
vector
<
char
>
v
(
size
);
receive_bytes
<
char
>
(
socket
,
v
.
data
(),
size
);
return
std
::
string
(
v
.
data
(),
v
.
size
());
}
}
// namespace tcputils
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/store/tcp_utils.h
0 → 100644
浏览文件 @
b95cd3b7
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#else
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <chrono>
#include <iostream>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
// Utility functions for TCP socket.
namespace
paddle
{
namespace
distributed
{
#ifdef _WIN32
using
SocketType
=
SOCKET
;
#else
using
SocketType
=
int
;
#endif
namespace
tcputils
{
constexpr
int
LISTENQ
=
2048
;
constexpr
std
::
chrono
::
seconds
kDelay
=
std
::
chrono
::
seconds
(
3
);
constexpr
std
::
chrono
::
seconds
kNoTimeout
=
std
::
chrono
::
seconds
::
zero
();
constexpr
std
::
chrono
::
seconds
kDefaultTimeout
=
std
::
chrono
::
seconds
(
360
);
std
::
error_code
socket_error
();
void
close_socket
(
SocketType
socket
);
::
addrinfo
*
get_addr_info
(
const
std
::
string
host
,
const
std
::
string
port
,
int
ai_flags
,
int
family
);
void
free_addr_info
(
::
addrinfo
*
);
SocketType
tcp_connect
(
const
std
::
string
host
,
const
std
::
string
port
,
int
family
,
std
::
chrono
::
seconds
timeout
=
kNoTimeout
);
SocketType
tcp_listen
(
const
std
::
string
host
,
const
std
::
string
port
,
int
family
);
SocketType
tcp_accept
(
SocketType
socket
);
void
send_string
(
SocketType
socket
,
const
std
::
string
&
s
);
std
::
string
receive_string
(
SocketType
socket
);
template
<
typename
T
>
void
send_bytes
(
SocketType
socket
,
const
T
*
buffer
,
size_t
len
)
{
size_t
to_send
=
len
*
sizeof
(
T
);
if
(
to_send
==
0
)
{
return
;
}
auto
ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
while
(
to_send
>
0
)
{
auto
byte_sent
=
::
send
(
socket
,
ptr
,
to_send
,
0
);
PADDLE_ENFORCE_GT
(
byte_sent
,
0
,
platform
::
errors
::
InvalidArgument
(
"TCP send error. Details: %s."
,
socket_error
().
message
()));
to_send
-=
byte_sent
;
ptr
+=
byte_sent
;
}
}
template
<
typename
T
>
void
receive_bytes
(
SocketType
socket
,
T
*
buffer
,
size_t
len
)
{
size_t
to_recv
=
len
*
sizeof
(
T
);
if
(
to_recv
==
0
)
{
return
;
}
auto
ptr
=
reinterpret_cast
<
char
*>
(
buffer
);
while
(
to_recv
>
0
)
{
auto
byte_received
=
::
recv
(
socket
,
ptr
,
to_recv
,
0
);
PADDLE_ENFORCE_GT
(
byte_received
,
0
,
platform
::
errors
::
InvalidArgument
(
"TCP receive error. Details: %s."
,
socket_error
().
message
()));
to_recv
-=
byte_received
;
ptr
+=
byte_received
;
}
}
template
<
typename
T
>
void
send_vector
(
SocketType
socket
,
const
std
::
vector
<
T
>&
v
)
{
size_t
size
=
v
.
size
();
send_bytes
<
size_t
>
(
socket
,
&
size
,
1
);
send_bytes
<
T
>
(
socket
,
v
.
data
(),
size
);
}
template
<
typename
T
>
std
::
vector
<
T
>
receive_vector
(
SocketType
socket
)
{
size_t
size
;
receive_bytes
<
size_t
>
(
socket
,
&
size
,
1
);
std
::
vector
<
T
>
res
(
size
);
receive_bytes
<
T
>
(
socket
,
res
.
data
(),
size
);
return
res
;
}
template
<
typename
T
>
void
send_value
(
SocketType
socket
,
const
T
&
v
)
{
send_bytes
<
T
>
(
socket
,
&
v
,
1
);
}
template
<
typename
T
>
T
receive_value
(
SocketType
socket
)
{
T
v
;
receive_bytes
<
T
>
(
socket
,
&
v
,
1
);
return
v
;
}
}
// namespace tcputils
}
// namespace distributed
}
// namespace paddle
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
b95cd3b7
...
@@ -2,7 +2,7 @@ set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_
...
@@ -2,7 +2,7 @@ set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils
)
cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils
tcp_store
)
if
(
WITH_PSCORE
)
if
(
WITH_PSCORE
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
ps_service
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
ps_service
)
...
@@ -73,6 +73,7 @@ set(PYBIND_SRCS
...
@@ -73,6 +73,7 @@ set(PYBIND_SRCS
compatible.cc
compatible.cc
io.cc
io.cc
generator_py.cc
generator_py.cc
communication.cc
cuda_streams_py.cc
)
cuda_streams_py.cc
)
if
(
WITH_ASCEND
)
if
(
WITH_ASCEND
)
...
...
paddle/fluid/pybind/communication.cc
0 → 100644
浏览文件 @
b95cd3b7
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <chrono>
#include <string>
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/pybind/communication.h"
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
using
TCPStore
=
paddle
::
distributed
::
TCPStore
;
void
BindTCPStore
(
py
::
module
*
m
)
{
py
::
class_
<
TCPStore
>
(
*
m
,
"TCPStore"
)
.
def
(
py
::
init
<
std
::
string
,
uint16_t
,
bool
,
size_t
,
std
::
chrono
::
seconds
>
())
.
def
(
"add"
,
&
TCPStore
::
add
)
.
def
(
"get"
,
&
TCPStore
::
get
);
}
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/communication.h
0 → 100644
浏览文件 @
b95cd3b7
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "pybind11/chrono.h"
#include "pybind11/complex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace
paddle
{
namespace
pybind
{
void
BindTCPStore
(
pybind11
::
module
*
m
);
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
b95cd3b7
...
@@ -91,6 +91,7 @@ limitations under the License. */
...
@@ -91,6 +91,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/bind_cost_model.h"
#include "paddle/fluid/pybind/bind_cost_model.h"
#include "paddle/fluid/pybind/bind_fleet_executor.h"
#include "paddle/fluid/pybind/bind_fleet_executor.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/communication.h"
#include "paddle/fluid/pybind/compatible.h"
#include "paddle/fluid/pybind/compatible.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/data_set_py.h"
...
@@ -2621,6 +2622,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2621,6 +2622,7 @@ All parameter, weight, gradient are variables in Paddle.
BindGlobalValueGetterSetter
(
&
m
);
BindGlobalValueGetterSetter
(
&
m
);
BindProcessMeshDesc
(
&
m
);
BindProcessMeshDesc
(
&
m
);
BindFleetExecutor
(
&
m
);
BindFleetExecutor
(
&
m
);
BindTCPStore
(
&
m
);
py
::
class_
<
framework
::
LoDRankTable
>
(
m
,
"LodRankTable"
)
py
::
class_
<
framework
::
LoDRankTable
>
(
m
,
"LodRankTable"
)
.
def
(
"items"
,
[](
framework
::
LoDRankTable
&
table
)
{
.
def
(
"items"
,
[](
framework
::
LoDRankTable
&
table
)
{
...
...
python/paddle/fluid/tests/unittests/test_tcp_store.py
0 → 100644
浏览文件 @
b95cd3b7
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
datetime
import
paddle
class
TestTCPStore
(
unittest
.
TestCase
):
def
test_tcp_store
(
self
):
store
=
paddle
.
fluid
.
core
.
TCPStore
(
"127.0.0.1"
,
6170
,
True
,
1
,
datetime
.
timedelta
(
0
))
store
.
add
(
"my"
,
3
)
ret1
=
store
.
get
(
'my'
)
store
.
add
(
"my"
,
3
)
ret2
=
store
.
get
(
'my'
)
self
.
assertEqual
(
ret1
[
0
]
+
3
,
ret2
[
0
])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录