Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9025fddd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9025fddd
编写于
2月 17, 2023
作者:
W
Wen Sun
提交者:
GitHub
2月 17, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add rpc ops to fetch data from remote service (#50220)
上级
0699afb1
变更
29
展开全部
显示空白变更内容
内联
并排
Showing
29 changed file
with
25936 addition
and
32 deletion
+25936
-32
.pre-commit-config.yaml
.pre-commit-config.yaml
+8
-2
cmake/generic.cmake
cmake/generic.cmake
+1
-1
cmake/third_party.cmake
cmake/third_party.cmake
+12
-0
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-2
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+1
-1
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+4
-4
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+3
-3
paddle/fluid/distributed/fleet_executor/message_service.cc
paddle/fluid/distributed/fleet_executor/message_service.cc
+1
-1
paddle/fluid/distributed/fleet_executor/message_service.h
paddle/fluid/distributed/fleet_executor/message_service.h
+1
-1
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+1
-3
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+12
-10
paddle/fluid/operators/collective/CMakeLists.txt
paddle/fluid/operators/collective/CMakeLists.txt
+7
-0
paddle/fluid/operators/collective/rpc_call_op.cc
paddle/fluid/operators/collective/rpc_call_op.cc
+67
-0
paddle/fluid/operators/collective/rpc_call_op.h
paddle/fluid/operators/collective/rpc_call_op.h
+184
-0
paddle/fluid/operators/collective/rpc_result_op.cc
paddle/fluid/operators/collective/rpc_result_op.cc
+62
-0
paddle/fluid/operators/collective/rpc_result_op.h
paddle/fluid/operators/collective/rpc_result_op.h
+172
-0
paddle/fluid/operators/collective/thirdparty/json.h
paddle/fluid/operators/collective/thirdparty/json.h
+24596
-0
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+18
-0
paddle/fluid/platform/rpc_utils.cc
paddle/fluid/platform/rpc_utils.cc
+312
-0
paddle/fluid/platform/rpc_utils.h
paddle/fluid/platform/rpc_utils.h
+176
-0
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
...on/paddle/fluid/tests/unittests/collective/CMakeLists.txt
+11
-0
python/paddle/fluid/tests/unittests/collective/py_server_test.py
...paddle/fluid/tests/unittests/collective/py_server_test.py
+42
-0
python/paddle/fluid/tests/unittests/collective/test_rpc_call_result.py
.../fluid/tests/unittests/collective/test_rpc_call_result.py
+107
-0
python/paddle/fluid/tests/unittests/collective/test_rpc_call_result.sh
.../fluid/tests/unittests/collective/test_rpc_call_result.sh
+15
-0
python/paddle/fluid/tests/unittests/collective/testslist.csv
python/paddle/fluid/tests/unittests/collective/testslist.csv
+1
-0
python/paddle/static/nn/__init__.py
python/paddle/static/nn/__init__.py
+5
-1
python/paddle/static/nn/rpc_utils.py
python/paddle/static/nn/rpc_utils.py
+112
-0
python/paddle/tensor/stat.py
python/paddle/tensor/stat.py
+2
-2
tools/codestyle/cpplint_pre_commit.hook
tools/codestyle/cpplint_pre_commit.hook
+2
-1
未找到文件。
.pre-commit-config.yaml
浏览文件 @
9025fddd
...
@@ -18,6 +18,10 @@ repos:
...
@@ -18,6 +18,10 @@ repos:
rev
:
v4.1.0
rev
:
v4.1.0
hooks
:
hooks
:
-
id
:
check-added-large-files
-
id
:
check-added-large-files
exclude
:
|
(?x)^(
paddle/fluid/operators/collective/thirdparty/json.h
)$
-
id
:
check-merge-conflict
-
id
:
check-merge-conflict
-
id
:
check-symlinks
-
id
:
check-symlinks
-
id
:
detect-private-key
-
id
:
detect-private-key
...
@@ -35,7 +39,8 @@ repos:
...
@@ -35,7 +39,8 @@ repos:
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
exclude
:
|
exclude
:
|
(?x)^(
(?x)^(
paddle/fluid/distributed/ps/thirdparty/round_robin.h
paddle/fluid/distributed/ps/thirdparty/round_robin.h|
paddle/fluid/operators/collective/thirdparty/json.h
)$
)$
-
repo
:
local
-
repo
:
local
hooks
:
hooks
:
...
@@ -62,7 +67,8 @@ repos:
...
@@ -62,7 +67,8 @@ repos:
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
exclude
:
|
exclude
:
|
(?x)^(
(?x)^(
paddle/utils/.*
paddle/utils/.*|
paddle/fluid/operators/collective/thirdparty/json.h
)$
)$
-
repo
:
local
-
repo
:
local
hooks
:
hooks
:
...
...
cmake/generic.cmake
浏览文件 @
9025fddd
...
@@ -96,7 +96,7 @@ if(NOT APPLE AND NOT WIN32)
...
@@ -96,7 +96,7 @@ if(NOT APPLE AND NOT WIN32)
link_libraries
(
${
CMAKE_THREAD_LIBS_INIT
}
)
link_libraries
(
${
CMAKE_THREAD_LIBS_INIT
}
)
if
(
WITH_PSLIB OR WITH_DISTRIBUTE
)
if
(
WITH_PSLIB OR WITH_DISTRIBUTE
)
set
(
CMAKE_CXX_LINK_EXECUTABLE
set
(
CMAKE_CXX_LINK_EXECUTABLE
"
${
CMAKE_CXX_LINK_EXECUTABLE
}
-pthread -ldl -lrt -lz -lssl"
)
"
${
CMAKE_CXX_LINK_EXECUTABLE
}
-pthread -ldl -lrt -lz -lssl
-lcrypto
"
)
else
()
else
()
set
(
CMAKE_CXX_LINK_EXECUTABLE
set
(
CMAKE_CXX_LINK_EXECUTABLE
"
${
CMAKE_CXX_LINK_EXECUTABLE
}
-pthread -ldl -lrt"
)
"
${
CMAKE_CXX_LINK_EXECUTABLE
}
-pthread -ldl -lrt"
)
...
...
cmake/third_party.cmake
浏览文件 @
9025fddd
...
@@ -424,6 +424,18 @@ if(WITH_PSCORE)
...
@@ -424,6 +424,18 @@ if(WITH_PSCORE)
list
(
APPEND third_party_deps extern_rocksdb
)
list
(
APPEND third_party_deps extern_rocksdb
)
endif
()
endif
()
if
(
WITH_DISTRIBUTE
AND NOT WITH_PSLIB
AND NOT WITH_PSCORE
)
include
(
external/snappy
)
list
(
APPEND third_party_deps extern_snappy
)
include
(
external/leveldb
)
list
(
APPEND third_party_deps extern_leveldb
)
include
(
external/brpc
)
list
(
APPEND third_party_deps extern_brpc
)
endif
()
if
(
WITH_XBYAK
)
if
(
WITH_XBYAK
)
include
(
external/xbyak
)
# download, build, install xbyak
include
(
external/xbyak
)
# download, build, install xbyak
list
(
APPEND third_party_deps extern_xbyak
)
list
(
APPEND third_party_deps extern_xbyak
)
...
...
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
9025fddd
add_subdirectory
(
auto_parallel
)
add_subdirectory
(
auto_parallel
)
add_subdirectory
(
collective
)
add_subdirectory
(
collective
)
add_subdirectory
(
store
)
add_subdirectory
(
store
)
add_subdirectory
(
fleet_executor
)
if
(
WITH_PYTHON
)
if
(
WITH_PYTHON
)
py_proto_compile
(
ps_py_proto SRCS the_one_ps.proto
)
py_proto_compile
(
ps_py_proto SRCS the_one_ps.proto
)
add_custom_target
(
add_custom_target
(
...
@@ -29,7 +30,6 @@ if(WITH_PYTHON)
...
@@ -29,7 +30,6 @@ if(WITH_PYTHON)
endif
()
endif
()
if
(
NOT WITH_PSCORE
)
if
(
NOT WITH_PSCORE
)
add_subdirectory
(
fleet_executor
)
return
()
return
()
endif
()
endif
()
...
@@ -47,4 +47,3 @@ add_subdirectory(common)
...
@@ -47,4 +47,3 @@ add_subdirectory(common)
add_subdirectory
(
ps
)
add_subdirectory
(
ps
)
add_subdirectory
(
test
)
add_subdirectory
(
test
)
add_subdirectory
(
index_dataset
)
add_subdirectory
(
index_dataset
)
add_subdirectory
(
fleet_executor
)
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
9025fddd
...
@@ -6,7 +6,7 @@ proto_library(interceptor_message_proto SRCS interceptor_message.proto)
...
@@ -6,7 +6,7 @@ proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if
(
WITH_ARM_BRPC
)
if
(
WITH_ARM_BRPC
)
set
(
BRPC_DEPS arm_brpc snappy gflags glog
)
set
(
BRPC_DEPS arm_brpc snappy gflags glog
)
elseif
(
WITH_DISTRIBUTE
AND WITH_PSCORE
)
elseif
(
WITH_DISTRIBUTE
)
set
(
BRPC_DEPS
set
(
BRPC_DEPS
brpc
brpc
ssl
ssl
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
9025fddd
...
@@ -73,7 +73,7 @@ bool MessageBus::IsInit() const { return is_init_; }
...
@@ -73,7 +73,7 @@ bool MessageBus::IsInit() const { return is_init_; }
MessageBus
::~
MessageBus
()
{
MessageBus
::~
MessageBus
()
{
VLOG
(
3
)
<<
"Message bus releases resource."
;
VLOG
(
3
)
<<
"Message bus releases resource."
;
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
server_
.
Stop
(
1000
);
server_
.
Stop
(
1000
);
server_
.
Join
();
server_
.
Join
();
#endif
#endif
...
@@ -94,7 +94,7 @@ bool MessageBus::Send(int64_t dst_rank,
...
@@ -94,7 +94,7 @@ bool MessageBus::Send(int64_t dst_rank,
true
,
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized."
));
"Using message bus since it has not been initialized."
));
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
int
retry_time
=
0
;
// message bus will retry sending for 10 times
int
retry_time
=
0
;
// message bus will retry sending for 10 times
while
(
retry_time
<
10
)
{
while
(
retry_time
<
10
)
{
++
retry_time
;
++
retry_time
;
...
@@ -179,7 +179,7 @@ void MessageBus::ListenPort() {
...
@@ -179,7 +179,7 @@ void MessageBus::ListenPort() {
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
return
;
return
;
}
}
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
// function keep listen the port and handle the message
// function keep listen the port and handle the message
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
server_
.
AddService
(
&
message_service_
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
),
server_
.
AddService
(
&
message_service_
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
),
...
@@ -209,7 +209,7 @@ void MessageBus::ListenPort() {
...
@@ -209,7 +209,7 @@ void MessageBus::ListenPort() {
#endif
#endif
}
}
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
bool
MessageBus
::
SendInterRank
(
int64_t
dst_rank
,
bool
MessageBus
::
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
)
{
const
InterceptorMessage
&
interceptor_message
)
{
const
auto
&
dst_addr
=
GetAddr
(
dst_rank
);
const
auto
&
dst_addr
=
GetAddr
(
dst_rank
);
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
9025fddd
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include <thread>
#include <thread>
#include <unordered_map>
#include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
#include "brpc/channel.h"
#include "brpc/channel.h"
#include "brpc/server.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
...
@@ -63,7 +63,7 @@ class MessageBus final {
...
@@ -63,7 +63,7 @@ class MessageBus final {
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
// send the message inter rank (dst is different rank with src)
// send the message inter rank (dst is different rank with src)
bool
SendInterRank
(
int64_t
dst_rank
,
bool
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
const
InterceptorMessage
&
interceptor_message
);
...
@@ -79,7 +79,7 @@ class MessageBus final {
...
@@ -79,7 +79,7 @@ class MessageBus final {
// the ip needs to be listened
// the ip needs to be listened
std
::
string
addr_
;
std
::
string
addr_
;
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
MessageServiceImpl
message_service_
;
MessageServiceImpl
message_service_
;
// brpc server
// brpc server
brpc
::
Server
server_
;
brpc
::
Server
server_
;
...
...
paddle/fluid/distributed/fleet_executor/message_service.cc
浏览文件 @
9025fddd
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "brpc/server.h"
...
...
paddle/fluid/distributed/fleet_executor/message_service.h
浏览文件 @
9025fddd
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE)
&& defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
#pragma once
#pragma once
#include "brpc/server.h"
#include "brpc/server.h"
...
...
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
9025fddd
...
@@ -59,9 +59,7 @@ cc_test(
...
@@ -59,9 +59,7 @@ cc_test(
scope
scope
device_context
)
device_context
)
if
(
WITH_DISTRIBUTE
if
(
WITH_DISTRIBUTE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
AND WITH_PSCORE
AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set_source_files_properties
(
set_source_files_properties
(
interceptor_ping_pong_with_brpc_test.cc
interceptor_ping_pong_with_brpc_test.cc
PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
9025fddd
...
@@ -584,6 +584,7 @@ if(WITH_PYTHON)
...
@@ -584,6 +584,7 @@ if(WITH_PYTHON)
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
if
(
NOT WITH_ROCM
)
add_custom_target
(
add_custom_target
(
fleet_executor_proto_init ALL
fleet_executor_proto_init ALL
DEPENDS fleet_proto_init fleet_executor_desc_py_proto
DEPENDS fleet_proto_init fleet_executor_desc_py_proto
...
@@ -594,6 +595,7 @@ if(WITH_PYTHON)
...
@@ -594,6 +595,7 @@ if(WITH_PYTHON)
COMMENT
COMMENT
"Copy generated python proto into directory paddle/distributed/fleet/proto."
"Copy generated python proto into directory paddle/distributed/fleet/proto."
)
)
endif
()
else
()
else
()
string
(
REPLACE
"/"
"
\\
"
proto_dstpath
string
(
REPLACE
"/"
"
\\
"
proto_dstpath
"
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/"
)
"
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/"
)
...
...
paddle/fluid/operators/collective/CMakeLists.txt
浏览文件 @
9025fddd
...
@@ -30,9 +30,16 @@ register_operators(
...
@@ -30,9 +30,16 @@ register_operators(
c_gen_hccl_id_op
c_gen_hccl_id_op
gen_hccl_id_op
gen_hccl_id_op
c_gen_cncl_id_op
c_gen_cncl_id_op
rpc_call_op
rpc_result_op
DEPS
DEPS
${
COLLECTIVE_DEPS
}
)
${
COLLECTIVE_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
op_library
(
rpc_call_op DEPS rpc_utils
${
COLLECTIVE_DEPS
}
)
op_library
(
rpc_result_op DEPS rpc_utils
${
COLLECTIVE_DEPS
}
)
endif
()
if
(
WITH_NCCL OR WITH_RCCL
)
if
(
WITH_NCCL OR WITH_RCCL
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
nccl_common collective_helper
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
nccl_common collective_helper
)
op_library
(
c_gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
op_library
(
c_gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
...
...
paddle/fluid/operators/collective/rpc_call_op.cc
0 → 100644
浏览文件 @
9025fddd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/rpc_call_op.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
RpcCallOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
dtype
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
return
framework
::
OpKernelType
(
dtype
,
ctx
.
GetPlace
());
}
};
class
RpcCallOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) Src words' ids."
);
AddOutput
(
"Out"
,
"(Tensor) Request id."
);
AddAttr
<
std
::
string
>
(
"url"
,
"URL."
).
SetDefault
({});
AddAttr
<
std
::
string
>
(
"vocab_path"
,
"Vocab's absolute path."
).
SetDefault
(
""
);
AddAttr
<
bool
>
(
"use_ids"
,
"If true, use ids directly."
).
SetDefault
(
true
);
AddAttr
<
int
>
(
"timeout"
,
"rpc connection timeout ms"
).
SetDefault
(
3000
);
AddAttr
<
int
>
(
"retry"
,
"rpc connection retry time"
).
SetDefault
(
100
);
AddComment
(
R"DOC(
Rpc Call Operator
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
rpc_call
,
ops
::
RpcCallOp
,
ops
::
RpcCallOpMaker
);
REGISTER_OP_CPU_KERNEL
(
rpc_call
,
ops
::
RpcCallOpKernel
<
int
>
,
ops
::
RpcCallOpKernel
<
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
rpc_call
,
ops
::
RpcCallOpKernel
<
int
>
,
ops
::
RpcCallOpKernel
<
int64_t
>
);
paddle/fluid/operators/collective/rpc_call_op.h
0 → 100644
浏览文件 @
9025fddd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <brpc/channel.h>
#include <fstream>
#include <memory>
#include <string>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/collective/thirdparty/json.h"
#include "paddle/fluid/platform/rpc_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace
paddle
{
namespace
operators
{
#define DATA_STRLIST 0
/*
{"data": ["你好"]}
*/
#define TEXT_STR 1
/*
{"text": "nihao"}
*/
using
json
=
nlohmann
::
json
;
// payload builders
template
<
typename
T
=
int64_t
>
static
inline
std
::
string
BuildIdsPayload
(
const
std
::
vector
<
T
>&
src_ids
)
{
json
payload
=
{{
"ids"
,
src_ids
}};
// => {"ids": [1, 2, 3, ...]}
return
payload
.
dump
();
}
static
inline
std
::
string
BuildStrPayload
(
const
std
::
string
&
query
,
int
build_way
)
{
json
payload
;
switch
(
build_way
)
{
case
DATA_STRLIST
:
payload
=
{{
"data"
,
{
query
}}};
//=> {"data": [query]}
break
;
case
TEXT_STR
:
payload
=
{{
"text"
,
query
}};
//=> {"text": query}
break
;
default:
break
;
}
return
payload
.
dump
();
}
template
<
typename
T
=
int64_t
>
static
inline
std
::
string
BuildPayload
(
const
std
::
string
&
service
,
const
std
::
vector
<
T
>&
src_ids
)
{
if
(
service
==
"ids"
)
{
return
BuildIdsPayload
(
src_ids
);
}
else
if
(
service
==
"str"
)
{
const
std
::
string
query
=
platform
::
RpcTokenizer
::
Instance
().
GetWordsFromIds
(
src_ids
);
return
BuildStrPayload
(
query
,
TEXT_STR
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unknown service."
));
}
}
// req & res handlers
static
inline
void
HandleServiceRequest
(
brpc
::
Controller
*
ctrl
,
int
request_id
,
const
std
::
string
&
payload
)
{
ctrl
->
request_attachment
().
append
(
payload
);
VLOG
(
3
)
<<
"Request id "
<<
request_id
<<
"payload size:"
<<
payload
.
size
();
VLOG
(
3
)
<<
"Request id "
<<
request_id
<<
" payload: "
<<
payload
;
}
static
inline
void
HandleServiceResponse
(
brpc
::
Controller
*
ctrl
,
int
request_id
,
std
::
shared_ptr
<
bthread
::
CountdownEvent
>
event
)
{
// make sure the controller will be deleted
std
::
unique_ptr
<
brpc
::
Controller
>
ctrl_guard
(
ctrl
);
auto
&
rpc_store
=
platform
::
RpcRequestStore
::
Instance
();
if
(
ctrl
->
Failed
())
{
rpc_store
.
InsertErrorCode
(
request_id
,
ctrl
->
ErrorCode
());
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Request id %s failed: access url error. error code: %d, http code: %d"
,
request_id
,
ctrl
->
ErrorCode
(),
ctrl
->
http_response
().
status_code
()));
}
else
{
const
std
::
string
res
=
ctrl
->
response_attachment
().
to_string
();
rpc_store
.
InsertErrorCode
(
request_id
,
0
);
rpc_store
.
InsertResponse
(
request_id
,
res
);
}
// try to notify result op
event
->
signal
();
}
static
int
send_sequence
(
const
framework
::
ExecutionContext
&
ctx
,
const
std
::
string
&
service
,
const
phi
::
DenseTensor
&
src_ids_tensor
,
const
std
::
string
&
url
,
const
int
&
timeout
=
3000
,
const
int
&
retry
=
100
)
{
std
::
vector
<
int
>
src_ids_vec
;
framework
::
TensorToVector
(
src_ids_tensor
,
ctx
.
device_context
(),
&
src_ids_vec
);
const
std
::
string
payload
=
BuildPayload
(
service
,
src_ids_vec
);
int
request_id
=
platform
::
RpcCommContext
::
RpcSend
(
url
,
payload
,
&
HandleServiceRequest
,
&
HandleServiceResponse
,
brpc
::
HttpMethod
::
HTTP_METHOD_POST
,
timeout
,
retry
);
VLOG
(
3
)
<<
"Request id "
<<
request_id
<<
" url: "
<<
url
;
VLOG
(
3
)
<<
"Request id "
<<
request_id
<<
" payload: "
<<
payload
;
return
request_id
;
}
template
<
typename
T
>
class
RpcCallOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// url, assume num of urls is limited
const
std
::
string
url
=
ctx
.
Attr
<
std
::
string
>
(
"url"
);
// payload
auto
src_ids_tensor
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
auto
x_dims
=
src_ids_tensor
->
dims
();
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
platform
::
errors
::
PreconditionNotMet
(
"The input src ids' dim size must be 2. However the dim is %d"
,
x_dims
.
size
()));
std
::
vector
<
int
>
request_ids
(
x_dims
[
0
]);
bool
use_ids
=
ctx
.
Attr
<
bool
>
(
"use_ids"
);
std
::
string
service
;
if
(
use_ids
)
{
service
=
"ids"
;
}
else
{
// init tokenizer
auto
vocab_path
=
ctx
.
Attr
<
std
::
string
>
(
"vocab_path"
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
special
;
platform
::
RpcTokenizer
::
Instance
().
Init
(
vocab_path
,
special
);
service
=
"str"
;
}
int
timeout
=
ctx
.
Attr
<
int
>
(
"timeout"
);
int
retry
=
ctx
.
Attr
<
int
>
(
"retry"
);
for
(
auto
i
=
0
;
i
<
x_dims
[
0
];
i
++
)
{
request_ids
[
i
]
=
send_sequence
(
ctx
,
service
,
src_ids_tensor
->
Slice
(
i
,
i
+
1
),
url
,
timeout
,
retry
);
}
auto
*
out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Out"
);
out
->
Resize
({
static_cast
<
int64_t
>
(
request_ids
.
size
())});
ctx
.
device_context
().
Alloc
<
int
>
(
out
);
framework
::
TensorFromVector
(
request_ids
,
ctx
.
device_context
(),
out
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/collective/rpc_result_op.cc
0 → 100644
浏览文件 @
9025fddd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/rpc_result_op.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
operators
{
class
RpcResultOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
dtype
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
return
framework
::
OpKernelType
(
dtype
,
ctx
.
GetPlace
());
}
};
class
RpcResultOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) Request id."
);
AddOutput
(
"Out"
,
"(Tensor) Response from service."
);
AddOutput
(
"succeed"
,
"Request status, true means succeed."
);
AddAttr
<
std
::
string
>
(
"res_type"
,
"Result type returns."
)
.
SetDefault
(
"float"
);
AddComment
(
R"DOC(
Rpc Result Operator
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
rpc_result
,
ops
::
RpcResultOp
,
ops
::
RpcResultOpMaker
);
REGISTER_OP_CPU_KERNEL
(
rpc_result
,
ops
::
RpcResultOpKernel
<
int
>
);
REGISTER_OP_CUDA_KERNEL
(
rpc_result
,
ops
::
RpcResultOpKernel
<
int
>
);
paddle/fluid/operators/collective/rpc_result_op.h
0 → 100644
浏览文件 @
9025fddd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/collective/thirdparty/json.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/rpc_utils.h"
namespace
paddle
{
namespace
operators
{
using
json
=
nlohmann
::
json
;
#define PARSE_DIRECT_FLOAT 0
/*
1.23
*/
#define PARSE_RESULT_FLOAT 1
/*
{"result": ["1.23"]}
*/
static
inline
std
::
vector
<
float
>
ParseFloatResponse
(
const
std
::
string
&
response
,
int
parse_way
)
{
auto
obj
=
json
::
parse
(
response
);
switch
(
parse_way
)
{
case
PARSE_RESULT_FLOAT
:
{
auto
res
=
obj
[
"result"
][
0
].
get
<
std
::
string
>
();
return
{
std
::
stof
(
res
,
nullptr
)};
}
case
PARSE_DIRECT_FLOAT
:
return
{
obj
.
get
<
float
>
()};
default:
break
;
}
return
{
static_cast
<
float
>
(
0
)};
}
static
inline
std
::
vector
<
uint8_t
>
ParseStrResponse
(
const
std
::
string
&
response
)
{
const
std
::
string
res
=
json
::
parse
(
response
).
dump
();
return
std
::
vector
<
uint8_t
>
(
res
.
begin
(),
res
.
end
());
}
static
std
::
vector
<
uint8_t
>
get_str_response
(
const
int
&
request_id
)
{
// wait for call op's event notification
auto
&
rpc_store
=
platform
::
RpcRequestStore
::
Instance
();
auto
event
=
rpc_store
.
GetEvent
(
request_id
);
int
err_code
=
rpc_store
.
GetErrorCode
(
request_id
);
bool
ok
=
event
->
wait
()
==
0
&&
err_code
==
0
;
if
(
ok
)
{
const
std
::
string
&
resp
=
rpc_store
.
GetResponse
(
request_id
);
VLOG
(
3
)
<<
"Request id "
<<
request_id
<<
" raw response: "
<<
resp
;
VLOG
(
3
)
<<
"Request id "
<<
request_id
;
// auto out_ = const_cast<phi::DenseTensor&>(out);
auto
out_vector
=
ParseStrResponse
(
resp
);
return
out_vector
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Request %s failed with error code %s."
,
request_id
,
err_code
));
}
}
static
std
::
vector
<
float
>
get_float_response
(
const
int
&
request_id
)
{
// wait for call op's event notification
auto
&
rpc_store
=
platform
::
RpcRequestStore
::
Instance
();
auto
event
=
rpc_store
.
GetEvent
(
request_id
);
int
err_code
=
rpc_store
.
GetErrorCode
(
request_id
);
bool
ok
=
event
->
wait
()
==
0
&&
err_code
==
0
;
if
(
ok
)
{
const
std
::
string
&
resp
=
rpc_store
.
GetResponse
(
request_id
);
VLOG
(
3
)
<<
"Request id "
<<
request_id
<<
" raw response: "
<<
resp
;
VLOG
(
3
)
<<
"Request id "
<<
request_id
;
// auto out_ = const_cast<phi::DenseTensor&>(out);
auto
out_vector
=
ParseFloatResponse
(
resp
,
PARSE_RESULT_FLOAT
);
return
out_vector
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Request %s failed with error code %s."
,
request_id
,
err_code
));
}
}
template
<
typename
T
>
class
RpcResultOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
request_id_tensor
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
std
::
vector
<
int
>
request_id_tensor_vec
;
framework
::
TensorToVector
(
*
request_id_tensor
,
ctx
.
device_context
(),
&
request_id_tensor_vec
);
auto
*
out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Out"
);
const
std
::
string
res_type
=
ctx
.
Attr
<
std
::
string
>
(
"res_type"
);
VLOG
(
3
)
<<
"out dims: "
<<
out
->
dims
().
to_str
()
<<
"numel: "
<<
out
->
numel
();
if
(
res_type
==
"str"
)
{
ctx
.
device_context
().
Alloc
<
uint8_t
>
(
out
);
}
else
if
(
res_type
==
"float"
)
{
ctx
.
device_context
().
Alloc
<
float
>
(
out
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unknown result type. error type: %s"
,
res_type
.
c_str
()));
}
VLOG
(
3
)
<<
"out dims: "
<<
out
->
dims
().
to_str
();
std
::
vector
<
std
::
vector
<
uint8_t
>>
uint8_vec
;
std
::
vector
<
std
::
vector
<
float
>>
float_vec
;
int64_t
max_size
=
-
1
;
for
(
auto
i
=
0
;
i
<
request_id_tensor
->
dims
()[
0
];
i
++
)
{
if
(
res_type
==
"float"
)
{
auto
vec
=
get_float_response
(
request_id_tensor_vec
[
i
]);
max_size
=
std
::
max
(
max_size
,
static_cast
<
int64_t
>
(
vec
.
size
()));
float_vec
.
emplace_back
(
vec
);
}
else
if
(
res_type
==
"str"
)
{
auto
vec
=
get_str_response
(
request_id_tensor_vec
[
i
]);
uint8_vec
.
emplace_back
(
vec
);
max_size
=
std
::
max
(
max_size
,
static_cast
<
int64_t
>
(
vec
.
size
()));
PADDLE_ENFORCE_LE
(
max_size
,
100
*
1024
*
1024
,
platform
::
errors
::
Unavailable
(
"to many string data, exceed 100MB"
));
}
}
out
->
Resize
({
request_id_tensor
->
dims
()[
0
],
max_size
});
if
(
res_type
==
"str"
)
{
ctx
.
device_context
().
Alloc
<
uint8_t
>
(
out
);
for
(
size_t
i
=
0
;
i
<
uint8_vec
.
size
();
i
++
)
{
phi
::
DenseTensor
out_
=
out
->
Slice
(
i
,
i
+
1
);
for
(
int
k
=
uint8_vec
[
i
].
size
();
k
<
max_size
;
k
++
)
{
uint8_vec
[
i
].
emplace_back
(
static_cast
<
uint8_t
>
(
0
));
}
framework
::
TensorFromVector
(
uint8_vec
[
i
],
ctx
.
device_context
(),
&
out_
);
}
}
else
if
(
res_type
==
"float"
)
{
ctx
.
device_context
().
Alloc
<
float
>
(
out
);
for
(
size_t
i
=
0
;
i
<
float_vec
.
size
();
i
++
)
{
phi
::
DenseTensor
out_
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
TensorFromVector
(
float_vec
[
i
],
ctx
.
device_context
(),
&
out_
);
}
}
auto
*
succeed
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"succeed"
);
ctx
.
device_context
().
Alloc
<
bool
>
(
succeed
);
std
::
vector
<
bool
>
succeed_wrapper
{
true
};
framework
::
TensorFromVector
(
succeed_wrapper
,
ctx
.
device_context
(),
succeed
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/collective/thirdparty/json.h
0 → 100644
浏览文件 @
9025fddd
此差异已折叠。
点击以展开。
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
9025fddd
...
@@ -223,6 +223,24 @@ cc_library(
...
@@ -223,6 +223,24 @@ cc_library(
phi_device_context
phi_device_context
generator
)
generator
)
if
(
WITH_DISTRIBUTE
)
set
(
BRPC_DEPS
brpc
ssl
crypto
protobuf
zlib
leveldb
snappy
gflags
glog
)
cc_library
(
rpc_utils
SRCS rpc_utils.cc
DEPS enforce
${
BRPC_DEPS
}
)
endif
()
cc_library
(
cc_library
(
collective_helper
collective_helper
SRCS collective_helper.cc gen_comm_id_helper.cc
SRCS collective_helper.cc gen_comm_id_helper.cc
...
...
paddle/fluid/platform/rpc_utils.cc
0 → 100644
浏览文件 @
9025fddd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/rpc_utils.h"
#include <algorithm>
#include <fstream>
#include <regex>
#include <sstream>
#include <unordered_set>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace
paddle
{
namespace
platform
{
// globals
static
std
::
wstring_convert
<
std
::
codecvt_utf8
<
wchar_t
>
,
wchar_t
>
converter
;
// utils
static
inline
bool
StartsWith
(
const
std
::
string
&
str
,
const
std
::
string
&
prefix
)
{
return
str
.
substr
(
0
,
prefix
.
length
())
==
prefix
;
}
static
inline
bool
EndsWith
(
const
std
::
string
&
str
,
const
std
::
string
&
suffix
)
{
return
str
.
length
()
>=
suffix
.
length
()
&&
str
.
substr
(
str
.
length
()
-
suffix
.
length
())
==
suffix
;
}
static
inline
std
::
string
Replace
(
const
std
::
string
&
str
,
const
std
::
string
&
old_str
,
const
std
::
string
&
new_str
)
{
if
(
old_str
==
new_str
)
{
return
str
;
}
std
::
stringstream
ss
;
size_t
start_pos
=
0
;
size_t
pos
=
str
.
find
(
old_str
,
start_pos
);
while
(
pos
!=
std
::
string
::
npos
)
{
ss
<<
str
.
substr
(
start_pos
,
pos
-
start_pos
)
<<
new_str
;
start_pos
=
pos
+
old_str
.
size
();
pos
=
str
.
find
(
old_str
,
start_pos
);
}
ss
<<
str
.
substr
(
start_pos
);
return
ss
.
str
();
}
static
inline
bool
IsChineseChar
(
wchar_t
c
)
{
return
(
c
>=
0x4E00
&&
c
<=
0x9FFF
)
||
(
c
>=
0x3400
&&
c
<=
0x4DBF
)
||
(
c
>=
0x20000
&&
c
<=
0x2A6DF
)
||
(
c
>=
0x2A700
&&
c
<=
0x2B73F
)
||
(
c
>=
0x2B740
&&
c
<=
0x2B81F
)
||
(
c
>=
0x2B820
&&
c
<=
0x2CEAF
)
||
(
c
>=
0xF900
&&
c
<=
0xFAFF
)
||
(
c
>=
0x2F800
&&
c
<=
0x2FA1F
);
}
static
inline
bool
IsChinesePunct
(
wchar_t
c
)
{
std
::
unordered_set
<
wchar_t
>
puncts
=
{
L'!'
,
L'?'
,
L'。'
,
L'。'
,
L'"'
,
L'#'
,
L'$'
,
L'%'
,
L'&'
,
L'''
,
L'('
,
L')'
,
L'*'
,
L'+'
,
L','
,
L'-'
,
L'/'
,
L':'
,
L';'
,
L'<'
,
L'='
,
L'>'
,
L'@'
,
L'['
,
L'\'
,
L']'
,
L'^'
,
L'_'
,
L'`'
,
L'{'
,
L'|'
,
L'}'
,
L'~'
,
L'⦅'
,
L'⦆'
,
L'「'
,
L'」'
,
L'、'
,
L'、'
,
L'〃'
,
L'》'
,
L'「'
,
L'」'
,
L'『'
,
L'』'
,
L'【'
,
L'】'
,
L'〔'
,
L'〕'
,
L'〖'
,
L'〗'
,
L'〘'
,
L'〙'
,
L'〚'
,
L'〛'
,
L'〜'
,
L'〝'
,
L'〞'
,
L'〟'
,
L'〰'
,
L'〾'
,
L'〿'
,
L'–'
,
L'—'
,
L'“'
,
L'”'
,
L'‘'
,
L'’'
};
return
puncts
.
count
(
c
);
}
static
inline
int
GetCharBytes
(
uint8_t
byte
)
{
if
((
byte
&
0x80
)
==
0
)
{
return
1
;
}
else
if
((
byte
&
0xE0
)
==
0xC0
)
{
return
2
;
}
else
if
((
byte
&
0xF0
)
==
0xE0
)
{
return
3
;
}
else
if
((
byte
&
0xF8
)
==
0xF0
)
{
return
4
;
}
else
{
return
-
1
;
}
}
static
inline
bool
IsValidContinuationByte
(
uint8_t
byte
)
{
// check if the byte starts with 10
return
(
byte
&
0xC0
)
==
0x80
;
}
static
inline
uint8_t
GetByteFromHex
(
const
std
::
string
&
token
)
{
auto
num_str
=
paddle
::
string
::
split_string
(
token
,
"_"
)[
1
];
num_str
=
num_str
.
substr
(
0
,
num_str
.
size
()
-
1
);
return
static_cast
<
uint8_t
>
(
std
::
stoi
(
num_str
,
nullptr
,
16
));
}
// RpcTokenizer
void
RpcTokenizer
::
Init
(
const
std
::
string
&
path
)
{
if
(
path_
==
path
)
{
return
;
}
std
::
ifstream
vocab_file
(
path
);
std
::
string
word
;
int
id
;
while
(
vocab_file
>>
word
>>
id
)
{
ids_to_words_
.
emplace
(
id
,
word
);
words_to_ids_
.
emplace
(
converter
.
from_bytes
(
word
),
id
);
}
// update members
path_
=
path
;
}
void
RpcTokenizer
::
Init
(
const
std
::
string
&
path
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
special_set
)
{
if
(
path_
==
path
)
{
return
;
}
Init
(
path
);
SetSpecialSet
(
special_set
);
}
std
::
string
RpcTokenizer
::
GetRecoveredToken
(
const
std
::
vector
<
uint8_t
>&
bytes
)
{
std
::
string
res
;
int
n
=
bytes
.
size
();
int
i
=
0
;
while
(
i
<
n
)
{
int
sz
=
0
;
while
((
sz
=
GetCharBytes
(
bytes
[
i
]))
==
-
1
)
{
++
i
;
}
if
(
i
+
sz
<
n
)
{
std
::
vector
<
uint8_t
>
valid_bytes
;
valid_bytes
.
emplace_back
(
bytes
[
i
]);
for
(
int
j
=
1
;
j
<
sz
;
++
j
)
{
if
(
!
IsValidContinuationByte
(
bytes
[
i
]))
{
break
;
}
valid_bytes
.
emplace_back
(
bytes
[
i
]);
++
i
;
}
if
(
valid_bytes
.
size
()
==
static_cast
<
size_t
>
(
sz
))
{
res
+=
std
::
string
(
valid_bytes
.
begin
(),
valid_bytes
.
end
());
}
}
++
i
;
}
return
res
;
}
std
::
vector
<
std
::
string
>
RpcTokenizer
::
RecoverBFBTokens
(
const
std
::
vector
<
std
::
string
>&
tokens
)
{
std
::
vector
<
std
::
string
>
new_tokens
;
std
::
vector
<
uint8_t
>
tmp_bytes
;
for
(
const
auto
&
token
:
tokens
)
{
if
(
StartsWith
(
token
,
"[BFB"
))
{
tmp_bytes
.
emplace_back
(
GetByteFromHex
(
token
));
}
else
{
if
(
!
tmp_bytes
.
empty
())
{
// since there may be illegal bytes, we need this function
// if all bytes are legal, we can simply use string constructor
const
std
::
string
recovered_token
=
GetRecoveredToken
(
tmp_bytes
);
if
(
!
recovered_token
.
empty
())
{
new_tokens
.
emplace_back
(
recovered_token
);
}
}
if
(
token
!=
"[UNK]"
)
{
new_tokens
.
emplace_back
(
token
);
}
tmp_bytes
.
clear
();
}
}
if
(
!
tmp_bytes
.
empty
())
{
const
std
::
string
recovered_token
=
GetRecoveredToken
(
tmp_bytes
);
if
(
!
recovered_token
.
empty
())
{
new_tokens
.
emplace_back
(
recovered_token
);
}
}
return
new_tokens
;
}
std
::
vector
<
std
::
string
>
RpcTokenizer
::
PostProcess
(
const
std
::
vector
<
std
::
string
>&
tokens
,
const
WordToIdMap
&
vocab
,
bool
aggressive_break
,
const
std
::
string
&
stop_token
)
{
std
::
unordered_set
<
std
::
string
>
break_words
;
if
(
aggressive_break
)
{
break_words
=
{
"[END]"
,
"[gEND]"
,
"[<S>]"
,
"[UNK]"
,
"[CLS]"
};
}
else
{
break_words
=
{
"[END]"
,
"[gEND]"
};
}
static
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
replace_words
{
{
"[<S>]"
,
" "
},
{
"[<N>]"
,
"
\n
"
},
{
"[<T>]"
,
"
\t
"
},
{
"[<t>]"
,
" "
},
};
std
::
vector
<
std
::
string
>
new_text
;
auto
words
=
RecoverBFBTokens
(
tokens
);
for
(
auto
&
word
:
words
)
{
if
(
break_words
.
count
(
word
)
||
word
==
stop_token
)
{
break
;
}
if
(
word
.
empty
()
||
word
==
"[PAD]"
)
{
continue
;
}
if
(
replace_words
.
count
(
word
))
{
new_text
.
emplace_back
(
replace_words
.
at
(
word
));
continue
;
}
auto
unicode_word
=
converter
.
from_bytes
(
word
);
bool
is_chinese_char
=
IsChineseChar
(
unicode_word
[
0
]);
bool
is_chinese_punct
=
IsChinesePunct
(
unicode_word
[
0
]);
if
(
is_chinese_char
||
is_chinese_punct
||
vocab
.
count
(
unicode_word
)
==
0
)
{
if
(
!
new_text
.
empty
()
&&
EndsWith
(
new_text
.
back
(),
"@@"
))
{
auto
&
last_word
=
new_text
.
back
();
last_word
=
Replace
(
last_word
,
"@@"
,
""
);
}
new_text
.
emplace_back
(
word
);
}
else
if
(
!
StartsWith
(
word
,
"##"
))
{
if
(
!
new_text
.
empty
()
&&
EndsWith
(
new_text
.
back
(),
"@@"
))
{
auto
&
last_word
=
new_text
.
back
();
last_word
=
Replace
(
last_word
,
"@@"
,
""
);
new_text
.
emplace_back
(
word
);
}
else
if
(
!
new_text
.
empty
()
&&
EndsWith
(
new_text
.
back
(),
"
\n
"
))
{
new_text
.
emplace_back
(
word
);
}
else
{
if
(
!
new_text
.
empty
()
&&
!
new_text
.
back
().
empty
()
&&
IsChineseChar
(
converter
.
from_bytes
(
new_text
.
back
())[
0
]))
{
new_text
.
emplace_back
(
word
);
}
else
{
if
(
!
new_text
.
empty
())
{
new_text
.
emplace_back
(
" "
);
}
new_text
.
emplace_back
(
word
);
}
}
}
else
{
if
(
!
new_text
.
empty
()
&&
EndsWith
(
new_text
.
back
(),
"@@"
))
{
auto
&
last_word
=
new_text
.
back
();
last_word
=
last_word
.
substr
(
0
,
last_word
.
size
()
-
2
);
}
new_text
.
emplace_back
(
Replace
(
word
,
"##"
,
""
));
}
}
if
(
!
new_text
.
empty
())
{
auto
&
last_word
=
new_text
.
back
();
last_word
=
Replace
(
last_word
,
"@@"
,
""
);
}
return
new_text
;
}
int
RpcCommContext
::
RpcSend
(
const
std
::
string
&
url
,
const
std
::
string
&
query
,
void
(
*
request_handler
)(
brpc
::
Controller
*
,
int
,
const
std
::
string
&
),
void
(
*
response_handler
)(
brpc
::
Controller
*
,
int
,
std
::
shared_ptr
<
bthread
::
CountdownEvent
>
),
brpc
::
HttpMethod
http_method
,
int
timeout_ms
,
int
max_retry
)
{
brpc
::
Channel
channel
;
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"http"
;
options
.
timeout_ms
=
timeout_ms
;
options
.
max_retry
=
max_retry
;
PADDLE_ENFORCE_EQ
(
channel
.
Init
(
url
.
c_str
(),
/*load_balancer*/
""
,
&
options
),
0
,
phi
::
errors
::
Unavailable
(
"Rpc send failed: init brpc channel error."
));
auto
&
rpc_store
=
RpcRequestStore
::
Instance
();
int
request_id
=
rpc_store
.
GetRequestId
();
auto
event
=
std
::
make_shared
<
bthread
::
CountdownEvent
>
();
RpcRequestStore
::
Instance
().
InsertEvent
(
request_id
,
event
);
// if req is async, controller should be on heap to avoid deleting
auto
*
ctrl
=
new
brpc
::
Controller
();
ctrl
->
http_request
().
uri
()
=
url
.
c_str
();
ctrl
->
http_request
().
set_method
(
http_method
);
ctrl
->
http_request
().
SetHeader
(
"Content-Type"
,
"application/json"
);
request_handler
(
ctrl
,
request_id
,
query
);
channel
.
CallMethod
(
nullptr
,
ctrl
,
nullptr
,
nullptr
,
brpc
::
NewCallback
(
response_handler
,
ctrl
,
request_id
,
event
));
return
request_id
;
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/rpc_utils.h
0 → 100644
浏览文件 @
9025fddd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <brpc/channel.h>
#include <bthread/countdown_event.h>
#include <atomic>
#include <codecvt>
#include <locale>
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/utils/string/string_helper.h"
namespace
paddle
{
namespace
platform
{
using
WordToIdMap
=
std
::
unordered_map
<
std
::
wstring
,
int64_t
>
;
using
IdToWordMap
=
std
::
unordered_map
<
int64_t
,
std
::
string
>
;
class
RpcTokenizer
{
public:
static
RpcTokenizer
&
Instance
()
{
static
RpcTokenizer
instance
;
return
instance
;
}
void
Init
(
const
std
::
string
&
path
);
void
Init
(
const
std
::
string
&
path
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
special_set
);
void
SetSpecialSet
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
special_set
)
{
special_set_
=
special_set
;
}
bool
Contains
(
int64_t
id
)
{
return
ids_to_words_
.
count
(
id
)
>
0
;
}
// NOTE: an exception will be raised if id not exist
std
::
string
GetWordFromId
(
int64_t
id
)
{
auto
q
=
ids_to_words_
.
at
(
id
);
if
(
special_set_
.
count
(
q
)
==
1
)
{
return
special_set_
.
at
(
q
);
}
else
{
return
q
;
}
}
template
<
typename
T
=
int64_t
>
std
::
string
GetWordsFromIds
(
const
std
::
vector
<
T
>&
ids
,
bool
aggressive_break
=
false
,
const
std
::
string
&
stop_token
=
"[gEND]"
)
{
std
::
vector
<
std
::
string
>
tokens
;
for
(
auto
id
:
ids
)
{
if
(
!
Contains
(
id
))
{
continue
;
}
tokens
.
emplace_back
(
GetWordFromId
(
id
));
}
return
paddle
::
string
::
join_strings
(
PostProcess
(
tokens
,
words_to_ids_
,
aggressive_break
,
stop_token
),
""
);
}
// NOTE: an exception will be raised if word not exist
int64_t
GetIdFromWord
(
const
std
::
wstring
&
word
)
{
return
words_to_ids_
.
at
(
word
);
}
private:
std
::
string
GetRecoveredToken
(
const
std
::
vector
<
uint8_t
>&
bytes
);
std
::
vector
<
std
::
string
>
RecoverBFBTokens
(
const
std
::
vector
<
std
::
string
>&
tokens
);
std
::
vector
<
std
::
string
>
PostProcess
(
const
std
::
vector
<
std
::
string
>&
tokens
,
const
WordToIdMap
&
vocab
,
bool
aggressive_break
=
false
,
const
std
::
string
&
stop_token
=
"[gEND]"
);
private:
std
::
wstring_convert
<
std
::
codecvt_utf8
<
wchar_t
>
,
wchar_t
>
converter_
;
std
::
string
path_
;
IdToWordMap
ids_to_words_
;
WordToIdMap
words_to_ids_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
special_set_
;
};
class
RpcRequestStore
{
public:
static
RpcRequestStore
&
Instance
()
{
static
RpcRequestStore
instance
;
return
instance
;
}
int
GetRequestId
()
{
if
(
request_id_
==
INT32_MAX
)
{
request_id_
=
0
;
}
else
{
++
request_id_
;
}
return
request_id_
;
}
std
::
shared_ptr
<
bthread
::
CountdownEvent
>
GetEvent
(
int
request_id
)
{
return
id_to_event_map_
[
request_id
];
}
int
GetErrorCode
(
int
request_id
)
{
return
id_to_err_map_
[
request_id
];
}
std
::
string
GetResponse
(
int
request_id
)
{
return
id_to_resp_map_
[
request_id
];
}
void
InsertEvent
(
int
request_id
,
const
std
::
shared_ptr
<
bthread
::
CountdownEvent
>&
event
)
{
if
(
request_id
==
0
)
{
LOG
(
WARNING
)
<<
"Total num of requests have exceeded int limits."
;
}
id_to_event_map_
.
emplace
(
request_id
,
event
);
}
void
InsertErrorCode
(
int
request_id
,
int
error_code
)
{
if
(
request_id
==
0
)
{
LOG
(
WARNING
)
<<
"Total num of requests have exceeded int limits."
;
}
id_to_err_map_
.
emplace
(
request_id
,
error_code
);
}
void
InsertResponse
(
int
request_id
,
const
std
::
string
&
resp
)
{
if
(
request_id
==
0
)
{
LOG
(
WARNING
)
<<
"Total num of requests have exceeded int limits."
;
}
id_to_resp_map_
.
emplace
(
request_id
,
resp
);
}
private:
std
::
atomic
<
int
>
request_id_
;
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
bthread
::
CountdownEvent
>>
id_to_event_map_
;
std
::
unordered_map
<
int
,
int
>
id_to_err_map_
;
std
::
unordered_map
<
int
,
std
::
string
>
id_to_resp_map_
;
};
struct
RpcCommContext
{
static
int
RpcSend
(
const
std
::
string
&
url
,
const
std
::
string
&
query
,
void
(
*
request_handler
)(
brpc
::
Controller
*
,
int
,
const
std
::
string
&
),
void
(
*
response_handler
)(
brpc
::
Controller
*
,
int
,
std
::
shared_ptr
<
bthread
::
CountdownEvent
>
),
brpc
::
HttpMethod
http_method
=
brpc
::
HttpMethod
::
HTTP_METHOD_POST
,
int
timeout_ms
=
10000
,
int
max_retry
=
3
);
};
}
// namespace platform
}
// namespace paddle
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
浏览文件 @
9025fddd
...
@@ -379,5 +379,16 @@ if((WITH_ROCM OR WITH_GPU) AND (LINUX))
...
@@ -379,5 +379,16 @@ if((WITH_ROCM OR WITH_GPU) AND (LINUX))
"PADDLE_DIST_UT_PORT=21532;http_proxy=;https_proxy="
)
"PADDLE_DIST_UT_PORT=21532;http_proxy=;https_proxy="
)
set_tests_properties
(
test_world_size_and_rank PROPERTIES TIMEOUT
"120"
)
set_tests_properties
(
test_world_size_and_rank PROPERTIES TIMEOUT
"120"
)
endif
()
endif
()
if
((
WITH_ROCM OR WITH_GPU
)
AND
(
LINUX
))
bash_test_modules
(
test_rpc_call_result
START_BASH
test_rpc_call_result.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21672;http_proxy=;https_proxy="
)
set_tests_properties
(
test_rpc_call_result PROPERTIES TIMEOUT
"120"
)
endif
()
add_subdirectory
(
fleet
)
add_subdirectory
(
fleet
)
add_subdirectory
(
multinode
)
add_subdirectory
(
multinode
)
python/paddle/fluid/tests/unittests/collective/py_server_test.py
0 → 100644
浏览文件 @
9025fddd
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
flask
import
Flask
,
request
,
jsonify
import
argparse
app
=
Flask
(
__name__
)
test_value
=
0.66943359375
@
app
.
route
(
'/run/predict'
,
methods
=
[
'POST'
])
def
echo
():
# Get the data from the request
request_json
=
request
.
json
# data = request_json['text']
# Echo the data back in the response
response
=
{
'result'
:
[
str
(
test_value
)]}
# Return the response in JSON format
return
jsonify
(
response
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--port'
,
type
=
int
,
required
=
True
,
help
=
'port'
)
parser
.
add_argument
(
'--ip'
,
type
=
str
,
required
=
False
,
default
=
'localhost'
,
help
=
'ip'
)
args
=
parser
.
parse_args
()
app
.
run
(
host
=
args
.
ip
,
port
=
args
.
port
)
python/paddle/fluid/tests/unittests/collective/test_rpc_call_result.py
0 → 100644
浏览文件 @
9025fddd
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
import
subprocess
import
unittest
import
os
def
rpc_test
(
use_ids
,
out_type
,
url
):
paddle
.
enable_static
()
MAX_SIZE_QUERY
=
18
RES_TYPE
=
out_type
with
open
(
"vocab.txt"
,
"w"
)
as
voc
:
voc
.
write
(
"ABC 0
\n
"
)
voc
.
write
(
"EFG 1
\n
"
)
voc
.
write
(
"HIG 2
\n
"
)
voc
.
write
(
"[<S>] 3
\n
"
)
voc
.
write
(
"[<N>] 4
\n
"
)
voc
.
write
(
"[<t>] 5
\n
"
)
voc
.
write
(
"[<T>] 6
\n
"
)
voc
.
write
(
"##good 7
\n
"
)
voc
.
write
(
"bad@@ 8
\n
"
)
voc
.
write
(
"@@badok 9
\n
"
)
voc
.
write
(
"你好 10
\n
"
)
voc
.
write
(
"haha 11
\n
"
)
voc
.
write
(
"##haha@@ 12
\n
"
)
voc
.
write
(
"[PAD] 13
\n
"
)
voc
.
write
(
"[gEnd] 14
\n
"
)
# network
in_query
=
fluid
.
data
(
name
=
'X'
,
shape
=
[
-
1
,
MAX_SIZE_QUERY
],
dtype
=
'int32'
)
req_ids
=
paddle
.
static
.
nn
.
rpc_call
(
in_query
,
url
,
"vocab.txt"
,
use_ids
,
)
out_data
,
out_succeed
=
paddle
.
static
.
nn
.
rpc_result
(
req_ids
,
RES_TYPE
)
paddle
.
static
.
Print
(
in_query
)
paddle
.
static
.
Print
(
req_ids
)
paddle
.
static
.
Print
(
out_data
.
astype
(
"float32"
))
query_tensor
=
np
.
array
(
[
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
0
,
1
,
2
],
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
0
,
1
,
2
,
14
],
]
).
astype
(
"int32"
)
# run
exe
=
fluid
.
Executor
(
fluid
.
CUDAPlace
(
0
))
exe
.
run
(
fluid
.
default_startup_program
())
for
_
in
range
(
1
):
succeed
,
data
,
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
'X'
:
query_tensor
,
},
fetch_list
=
[
out_succeed
,
out_data
],
)
if
out_type
==
"str"
:
print
(
data
[
0
].
tobytes
().
decode
(
"utf-8"
,
"ignore"
))
else
:
print
(
data
[
0
])
class
RPCCallTest
(
unittest
.
TestCase
):
def
test_cases
(
self
):
ip
=
'localhost'
port
=
int
(
os
.
environ
.
get
(
"PADDLE_DIST_UT_PORT"
))
server_cmd
=
f
"python py_server_test.py --ip
{
ip
}
--port
{
port
}
"
with
open
(
f
"server.
{
port
}
.log"
,
"w"
)
as
output
:
process
=
subprocess
.
Popen
(
server_cmd
.
split
(),
stdout
=
output
,
stderr
=
output
)
for
uid
in
[
True
,
False
]:
for
otype
in
[
'str'
,
'float'
]:
try
:
rpc_test
(
uid
,
otype
,
f
"http://
{
ip
}
:
{
port
}
/run/predict"
)
except
:
process
.
kill
()
raise
RuntimeError
(
"rpc test error"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_rpc_call_result.sh
0 → 100644
浏览文件 @
9025fddd
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python test_rpc_call_result.py
python/paddle/fluid/tests/unittests/collective/testslist.csv
浏览文件 @
9025fddd
...
@@ -45,3 +45,4 @@ test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_
...
@@ -45,3 +45,4 @@ test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_
test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
test_rpc_call_result,linux,rocm;gpu,120,DIST,test_rpc_call_result.sh,1,,http_proxy=;https_proxy=,
python/paddle/static/nn/__init__.py
浏览文件 @
9025fddd
...
@@ -59,7 +59,9 @@ from ...fluid.layers.sequence_lod import sequence_scatter # noqa: F401
...
@@ -59,7 +59,9 @@ from ...fluid.layers.sequence_lod import sequence_scatter # noqa: F401
from
...fluid.layers.sequence_lod
import
sequence_enumerate
# noqa: F401
from
...fluid.layers.sequence_lod
import
sequence_enumerate
# noqa: F401
from
...fluid.layers.sequence_lod
import
sequence_reverse
# noqa: F401
from
...fluid.layers.sequence_lod
import
sequence_reverse
# noqa: F401
__all__
=
[
#noqa
from
.rpc_utils
import
rpc_call
,
rpc_result
__all__
=
[
# noqa
'fc'
,
'fc'
,
'batch_norm'
,
'batch_norm'
,
'embedding'
,
'embedding'
,
...
@@ -101,4 +103,6 @@ __all__ = [ #noqa
...
@@ -101,4 +103,6 @@ __all__ = [ #noqa
'sequence_enumerate'
,
'sequence_enumerate'
,
'sequence_reverse'
,
'sequence_reverse'
,
'StaticRNN'
,
'StaticRNN'
,
'rpc_call'
,
'rpc_result'
,
]
]
python/paddle/static/nn/rpc_utils.py
0 → 100644
浏览文件 @
9025fddd
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle
import
fluid
import
paddle
class
IDGen
:
def
__init__
(
self
)
->
None
:
self
.
ids
=
{}
def
gen_name_with_idx
(
self
,
name
):
if
name
not
in
self
.
ids
:
self
.
ids
[
name
]
=
-
1
self
.
ids
[
name
]
+=
1
return
name
+
"_"
+
str
(
self
.
ids
[
name
])
def
__call__
(
self
,
name
)
->
str
:
return
self
.
gen_name_with_idx
(
name
)
id_gen
=
IDGen
()
def
rpc_call
(
src_ids
=
None
,
url
=
""
,
voc_path
=
""
,
cvt2str
=
True
):
request_id
=
(
fluid
.
default_main_program
()
.
block
(
0
)
.
create_var
(
name
=
id_gen
(
"rpc_request_id"
),
dtype
=
"int32"
,
shape
=
[
src_ids
.
shape
[
0
]],
persistable
=
False
,
stop_gradient
=
True
,
)
)
src_ids
=
src_ids
.
astype
(
"int32"
)
fluid
.
default_main_program
().
block
(
0
).
append_op
(
type
=
"rpc_call"
,
inputs
=
{
'X'
:
[
src_ids
],
},
outputs
=
{
"Out"
:
[
request_id
]},
attrs
=
{
"url"
:
url
,
"vocab_path"
:
voc_path
,
"use_ids"
:
not
cvt2str
,
"timeout"
:
3000
,
"retry"
:
100
,
},
)
return
request_id
def
rpc_result
(
request_ids
,
result_dtype
):
if
result_dtype
==
"float"
:
res
=
(
fluid
.
default_main_program
()
.
block
(
0
)
.
create_var
(
name
=
id_gen
(
"rpc_res"
),
dtype
=
"float32"
,
shape
=
[
request_ids
.
shape
[
0
]],
persistable
=
False
,
stop_gradient
=
True
,
)
)
elif
result_dtype
==
"str"
:
res
=
(
fluid
.
default_main_program
()
.
block
(
0
)
.
create_var
(
name
=
id_gen
(
"rpc_res"
),
dtype
=
"uint8"
,
shape
=
[
request_ids
.
shape
[
0
]],
persistable
=
False
,
stop_gradient
=
True
,
)
)
else
:
raise
ValueError
(
"result dtype must be one of str ot float"
)
success
=
(
fluid
.
default_main_program
()
.
block
(
0
)
.
create_var
(
name
=
id_gen
(
"rpc_success"
),
dtype
=
"bool"
,
shape
=
[
1
],
persistable
=
False
,
stop_gradient
=
True
,
)
)
fluid
.
default_main_program
().
block
(
0
).
append_op
(
type
=
"rpc_result"
,
inputs
=
{
"X"
:
[
request_ids
]},
outputs
=
{
"Out"
:
[
res
],
"succeed"
:
[
success
]},
attrs
=
{
"res_type"
:
result_dtype
},
)
return
res
,
success
python/paddle/tensor/stat.py
浏览文件 @
9025fddd
...
@@ -175,8 +175,8 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
...
@@ -175,8 +175,8 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
out
=
paddle
.
sum
((
x
-
u
)
**
2
,
axis
,
keepdim
=
keepdim
,
name
=
name
)
out
=
paddle
.
sum
((
x
-
u
)
**
2
,
axis
,
keepdim
=
keepdim
,
name
=
name
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
n
=
paddle
.
cast
(
paddle
.
numel
(
x
),
paddle
.
int64
)
/
paddle
.
cast
(
n
=
paddle
.
cast
(
paddle
.
numel
(
x
),
dtype
)
/
paddle
.
cast
(
paddle
.
numel
(
out
),
paddle
.
int64
paddle
.
numel
(
out
),
dtype
)
)
n
=
n
.
astype
(
dtype
)
n
=
n
.
astype
(
dtype
)
if
unbiased
:
if
unbiased
:
...
...
tools/codestyle/cpplint_pre_commit.hook
浏览文件 @
9025fddd
...
@@ -21,7 +21,8 @@ else
...
@@ -21,7 +21,8 @@ else
fi
fi
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for
file
in
$files
;
do
for
file
in
$files
;
do
if
[[
$file
=
~ ^
(
patches/.
*
)
]]
;
then
echo
$file
if
[[
$file
=
~ ^
(
patches/.
*
)
||
$file
=
~ ^
(
paddle/fluid/operators/collective/thirdparty/json.h
)
]]
;
then
continue
;
continue
;
else
else
cpplint
--filter
=
-readability
/fn_size,-build/include_what_you_use,-build/c++11,-whitespace/parens
$file
;
cpplint
--filter
=
-readability
/fn_size,-build/include_what_you_use,-build/c++11,-whitespace/parens
$file
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录