Unverified commit 9025fddd, authored by Wen Sun, committed by GitHub

Add rpc ops to fetch data from remote service (#50220)

Parent 0699afb1
......@@ -18,6 +18,10 @@ repos:
rev: v4.1.0
hooks:
- id: check-added-large-files
exclude: |
(?x)^(
paddle/fluid/operators/collective/thirdparty/json.h
)$
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
......@@ -35,7 +39,8 @@ repos:
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
exclude: |
(?x)^(
paddle/fluid/distributed/ps/thirdparty/round_robin.h
paddle/fluid/distributed/ps/thirdparty/round_robin.h|
paddle/fluid/operators/collective/thirdparty/json.h
)$
- repo: local
hooks:
......@@ -62,7 +67,8 @@ repos:
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
exclude: |
(?x)^(
paddle/utils/.*
paddle/utils/.*|
paddle/fluid/operators/collective/thirdparty/json.h
)$
- repo: local
hooks:
......
......@@ -96,7 +96,7 @@ if(NOT APPLE AND NOT WIN32)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
if(WITH_PSLIB OR WITH_DISTRIBUTE)
set(CMAKE_CXX_LINK_EXECUTABLE
"${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt -lz -lssl")
"${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt -lz -lssl -lcrypto")
else()
set(CMAKE_CXX_LINK_EXECUTABLE
"${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt")
......
......@@ -424,6 +424,18 @@ if(WITH_PSCORE)
list(APPEND third_party_deps extern_rocksdb)
endif()
if(WITH_DISTRIBUTE
AND NOT WITH_PSLIB
AND NOT WITH_PSCORE)
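  # brpc (plus its snappy/leveldb deps) is now needed even without the
  # parameter-server modes: the new rpc ops and the fleet executor message
  # bus link against it whenever WITH_DISTRIBUTE is on.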
include(external/snappy)
list(APPEND third_party_deps extern_snappy)
include(external/leveldb)
list(APPEND third_party_deps extern_leveldb)
include(external/brpc)
list(APPEND third_party_deps extern_brpc)
endif()
if(WITH_XBYAK)
include(external/xbyak) # download, build, install xbyak
list(APPEND third_party_deps extern_xbyak)
......
add_subdirectory(auto_parallel)
add_subdirectory(collective)
add_subdirectory(store)
add_subdirectory(fleet_executor)
if(WITH_PYTHON)
py_proto_compile(ps_py_proto SRCS the_one_ps.proto)
add_custom_target(
......@@ -29,7 +30,6 @@ if(WITH_PYTHON)
endif()
if(NOT WITH_PSCORE)
add_subdirectory(fleet_executor)
return()
endif()
......@@ -47,4 +47,3 @@ add_subdirectory(common)
add_subdirectory(ps)
add_subdirectory(test)
add_subdirectory(index_dataset)
add_subdirectory(fleet_executor)
......@@ -6,7 +6,7 @@ proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_ARM_BRPC)
set(BRPC_DEPS arm_brpc snappy gflags glog)
elseif(WITH_DISTRIBUTE AND WITH_PSCORE)
elseif(WITH_DISTRIBUTE)
set(BRPC_DEPS
brpc
ssl
......
......@@ -73,7 +73,7 @@ bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
server_.Stop(1000);
server_.Join();
#endif
......@@ -94,7 +94,7 @@ bool MessageBus::Send(int64_t dst_rank,
true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
int retry_time = 0;  // message bus will retry sending up to 10 times
while (retry_time < 10) {
++retry_time;
......@@ -179,7 +179,7 @@ void MessageBus::ListenPort() {
LOG(INFO) << "No need listen to port since training on single card.";
return;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
// this function keeps listening on the port and handles incoming messages
PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
......@@ -209,7 +209,7 @@ void MessageBus::ListenPort() {
#endif
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
bool MessageBus::SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
const auto& dst_addr = GetAddr(dst_rank);
......
......@@ -20,7 +20,7 @@
#include <thread>
#include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
......@@ -63,7 +63,7 @@ class MessageBus final {
const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
// send the message across ranks (dst rank differs from src rank)
bool SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message);
......@@ -79,7 +79,7 @@ class MessageBus final {
// the address to listen on
std::string addr_;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
MessageServiceImpl message_service_;
// brpc server
brpc::Server server_;
......
......@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
......
......@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#if defined(PADDLE_WITH_DISTRIBUTE)
#pragma once
#include "brpc/server.h"
......
......@@ -59,9 +59,7 @@ cc_test(
scope
device_context)
if(WITH_DISTRIBUTE
AND WITH_PSCORE
AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(
interceptor_ping_pong_with_brpc_test.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -584,16 +584,18 @@ if(WITH_PYTHON)
${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_target(
fleet_executor_proto_init ALL
DEPENDS fleet_proto_init fleet_executor_desc_py_proto
COMMAND
cp
${PADDLE_BINARY_DIR}/paddle/fluid/distributed/fleet_executor/fleet_executor_*.py
${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMENT
"Copy generated python proto into directory paddle/distributed/fleet/proto."
)
if(NOT WITH_ROCM)
add_custom_target(
fleet_executor_proto_init ALL
DEPENDS fleet_proto_init fleet_executor_desc_py_proto
COMMAND
cp
${PADDLE_BINARY_DIR}/paddle/fluid/distributed/fleet_executor/fleet_executor_*.py
${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMENT
"Copy generated python proto into directory paddle/distributed/fleet/proto."
)
endif()
else()
string(REPLACE "/" "\\" proto_dstpath
"${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/")
......
......@@ -30,9 +30,16 @@ register_operators(
c_gen_hccl_id_op
gen_hccl_id_op
c_gen_cncl_id_op
rpc_call_op
rpc_result_op
DEPS
${COLLECTIVE_DEPS})
if(WITH_DISTRIBUTE)
op_library(rpc_call_op DEPS rpc_utils ${COLLECTIVE_DEPS})
op_library(rpc_result_op DEPS rpc_utils ${COLLECTIVE_DEPS})
endif()
if(WITH_NCCL OR WITH_RCCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/rpc_call_op.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class RpcCallOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
class RpcCallOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Src words' ids.");
AddOutput("Out", "(Tensor) Request id.");
AddAttr<std::string>("url", "URL.").SetDefault({});
AddAttr<std::string>("vocab_path", "Vocab's absolute path.").SetDefault("");
AddAttr<bool>("use_ids", "If true, use ids directly.").SetDefault(true);
AddAttr<int>("timeout", "rpc connection timeout ms").SetDefault(3000);
AddAttr<int>("retry", "rpc connection retry time").SetDefault(100);
AddComment(R"DOC(
Rpc Call Operator
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rpc_call, ops::RpcCallOp, ops::RpcCallOpMaker);
REGISTER_OP_CPU_KERNEL(rpc_call,
ops::RpcCallOpKernel<int>,
ops::RpcCallOpKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(rpc_call,
ops::RpcCallOpKernel<int>,
ops::RpcCallOpKernel<int64_t>);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <brpc/channel.h>
#include <fstream>
#include <memory>
#include <string>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/collective/thirdparty/json.h"
#include "paddle/fluid/platform/rpc_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace operators {
#define DATA_STRLIST 0
/*
{"data": ["你好"]}
*/
#define TEXT_STR 1
/*
{"text": "nihao"}
*/
using json = nlohmann::json;
// payload builders
template <typename T = int64_t>
static inline std::string BuildIdsPayload(const std::vector<T>& src_ids) {
json payload = {{"ids", src_ids}}; // => {"ids": [1, 2, 3, ...]}
return payload.dump();
}
static inline std::string BuildStrPayload(const std::string& query,
int build_way) {
json payload;
switch (build_way) {
case DATA_STRLIST:
payload = {{"data", {query}}}; //=> {"data": [query]}
break;
case TEXT_STR:
payload = {{"text", query}}; //=> {"text": query}
break;
default:
break;
}
return payload.dump();
}
template <typename T = int64_t>
static inline std::string BuildPayload(const std::string& service,
const std::vector<T>& src_ids) {
if (service == "ids") {
return BuildIdsPayload(src_ids);
} else if (service == "str") {
const std::string query =
platform::RpcTokenizer::Instance().GetWordsFromIds(src_ids);
return BuildStrPayload(query, TEXT_STR);
} else {
PADDLE_THROW(platform::errors::InvalidArgument("Unknown service."));
}
}
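// Example payloads these builders emit (illustrative values):
//   BuildIdsPayload({1, 2, 3})          => {"ids":[1,2,3]}
//   BuildStrPayload("hi", DATA_STRLIST) => {"data":["hi"]}
//   BuildStrPayload("hi", TEXT_STR)     => {"text":"hi"}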
// req & res handlers
static inline void HandleServiceRequest(brpc::Controller* ctrl,
int request_id,
const std::string& payload) {
ctrl->request_attachment().append(payload);
VLOG(3) << "Request id " << request_id << "payload size:" << payload.size();
VLOG(3) << "Request id " << request_id << " payload: " << payload;
}
static inline void HandleServiceResponse(
brpc::Controller* ctrl,
int request_id,
std::shared_ptr<bthread::CountdownEvent> event) {
// make sure the controller will be deleted
std::unique_ptr<brpc::Controller> ctrl_guard(ctrl);
auto& rpc_store = platform::RpcRequestStore::Instance();
if (ctrl->Failed()) {
rpc_store.InsertErrorCode(request_id, ctrl->ErrorCode());
PADDLE_THROW(platform::errors::Unavailable(
"Request id %s failed: access url error. error code: %d, http code: %d",
request_id,
ctrl->ErrorCode(),
ctrl->http_response().status_code()));
} else {
const std::string res = ctrl->response_attachment().to_string();
rpc_store.InsertErrorCode(request_id, 0);
rpc_store.InsertResponse(request_id, res);
}
// try to notify result op
event->signal();
}
static int send_sequence(const framework::ExecutionContext& ctx,
const std::string& service,
const phi::DenseTensor& src_ids_tensor,
const std::string& url,
const int& timeout = 3000,
const int& retry = 100) {
std::vector<int> src_ids_vec;
framework::TensorToVector(src_ids_tensor, ctx.device_context(), &src_ids_vec);
const std::string payload = BuildPayload(service, src_ids_vec);
int request_id =
platform::RpcCommContext::RpcSend(url,
payload,
&HandleServiceRequest,
&HandleServiceResponse,
brpc::HttpMethod::HTTP_METHOD_POST,
timeout,
retry);
VLOG(3) << "Request id " << request_id << " url: " << url;
VLOG(3) << "Request id " << request_id << " payload: " << payload;
return request_id;
}
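// Request lifecycle in this file: send_sequence issues one async HTTP POST
// per input row via RpcCommContext::RpcSend; HandleServiceResponse later
// stores the response (or error code) in RpcRequestStore and signals the
// CountdownEvent that the matching rpc_result op blocks on.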
template <typename T>
class RpcCallOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// url, assume num of urls is limited
const std::string url = ctx.Attr<std::string>("url");
// payload
auto src_ids_tensor = ctx.Input<phi::DenseTensor>("X");
auto x_dims = src_ids_tensor->dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2,
platform::errors::PreconditionNotMet(
"The input src ids' dim size must be 2. However the dim is %d",
x_dims.size()));
std::vector<int> request_ids(x_dims[0]);
bool use_ids = ctx.Attr<bool>("use_ids");
std::string service;
if (use_ids) {
service = "ids";
} else {
// init tokenizer
auto vocab_path = ctx.Attr<std::string>("vocab_path");
std::unordered_map<std::string, std::string> special;
platform::RpcTokenizer::Instance().Init(vocab_path, special);
service = "str";
}
int timeout = ctx.Attr<int>("timeout");
int retry = ctx.Attr<int>("retry");
    for (int64_t i = 0; i < x_dims[0]; i++) {
request_ids[i] = send_sequence(
ctx, service, src_ids_tensor->Slice(i, i + 1), url, timeout, retry);
}
auto* out = ctx.Output<phi::DenseTensor>("Out");
out->Resize({static_cast<int64_t>(request_ids.size())});
ctx.device_context().Alloc<int>(out);
framework::TensorFromVector(request_ids, ctx.device_context(), out);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/rpc_result_op.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace operators {
class RpcResultOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
class RpcResultOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Request id.");
AddOutput("Out", "(Tensor) Response from service.");
AddOutput("succeed", "Request status, true means succeed.");
AddAttr<std::string>("res_type", "Result type returns.")
.SetDefault("float");
AddComment(R"DOC(
Rpc Result Operator
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rpc_result,
ops::RpcResultOp,
ops::RpcResultOpMaker);
REGISTER_OP_CPU_KERNEL(rpc_result, ops::RpcResultOpKernel<int>);
REGISTER_OP_CUDA_KERNEL(rpc_result, ops::RpcResultOpKernel<int>);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/collective/thirdparty/json.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/rpc_utils.h"
namespace paddle {
namespace operators {
using json = nlohmann::json;
#define PARSE_DIRECT_FLOAT 0
/*
1.23
*/
#define PARSE_RESULT_FLOAT 1
/*
{"result": ["1.23"]}
*/
static inline std::vector<float> ParseFloatResponse(const std::string& response,
int parse_way) {
auto obj = json::parse(response);
switch (parse_way) {
case PARSE_RESULT_FLOAT: {
auto res = obj["result"][0].get<std::string>();
return {std::stof(res, nullptr)};
}
case PARSE_DIRECT_FLOAT:
return {obj.get<float>()};
default:
break;
}
return {static_cast<float>(0)};
}
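// e.g. ParseFloatResponse("{\"result\": [\"1.23\"]}", PARSE_RESULT_FLOAT)
// yields {1.23f}; ParseFloatResponse("1.23", PARSE_DIRECT_FLOAT) does too.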
static inline std::vector<uint8_t> ParseStrResponse(
const std::string& response) {
const std::string res = json::parse(response).dump();
return std::vector<uint8_t>(res.begin(), res.end());
}
static std::vector<uint8_t> get_str_response(int request_id) {
  // wait for the call op's event notification before reading any state:
  // the error code and response are only stored once the handler has run
  auto& rpc_store = platform::RpcRequestStore::Instance();
  auto event = rpc_store.GetEvent(request_id);
  bool ok = event->wait() == 0 && rpc_store.GetErrorCode(request_id) == 0;
  if (ok) {
    const std::string& resp = rpc_store.GetResponse(request_id);
    VLOG(3) << "Request id " << request_id << " raw response: " << resp;
    return ParseStrResponse(resp);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Request %d failed with error code %d.",
        request_id,
        rpc_store.GetErrorCode(request_id)));
  }
}
static std::vector<float> get_float_response(int request_id) {
  // wait for the call op's event notification before reading any state:
  // the error code and response are only stored once the handler has run
  auto& rpc_store = platform::RpcRequestStore::Instance();
  auto event = rpc_store.GetEvent(request_id);
  bool ok = event->wait() == 0 && rpc_store.GetErrorCode(request_id) == 0;
  if (ok) {
    const std::string& resp = rpc_store.GetResponse(request_id);
    VLOG(3) << "Request id " << request_id << " raw response: " << resp;
    return ParseFloatResponse(resp, PARSE_RESULT_FLOAT);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Request %d failed with error code %d.",
        request_id,
        rpc_store.GetErrorCode(request_id)));
  }
}
template <typename T>
class RpcResultOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* request_id_tensor = ctx.Input<phi::DenseTensor>("X");
std::vector<int> request_id_tensor_vec;
framework::TensorToVector(
*request_id_tensor, ctx.device_context(), &request_id_tensor_vec);
auto* out = ctx.Output<phi::DenseTensor>("Out");
const std::string res_type = ctx.Attr<std::string>("res_type");
VLOG(3) << "out dims: " << out->dims().to_str()
<< "numel: " << out->numel();
if (res_type == "str") {
ctx.device_context().Alloc<uint8_t>(out);
} else if (res_type == "float") {
ctx.device_context().Alloc<float>(out);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown result type. error type: %s", res_type.c_str()));
}
VLOG(3) << "out dims: " << out->dims().to_str();
std::vector<std::vector<uint8_t>> uint8_vec;
std::vector<std::vector<float>> float_vec;
int64_t max_size = -1;
    for (int64_t i = 0; i < request_id_tensor->dims()[0]; i++) {
if (res_type == "float") {
auto vec = get_float_response(request_id_tensor_vec[i]);
max_size = std::max(max_size, static_cast<int64_t>(vec.size()));
float_vec.emplace_back(vec);
} else if (res_type == "str") {
auto vec = get_str_response(request_id_tensor_vec[i]);
uint8_vec.emplace_back(vec);
max_size = std::max(max_size, static_cast<int64_t>(vec.size()));
PADDLE_ENFORCE_LE(
max_size,
100 * 1024 * 1024,
platform::errors::Unavailable("to many string data, exceed 100MB"));
}
}
out->Resize({request_id_tensor->dims()[0], max_size});
if (res_type == "str") {
ctx.device_context().Alloc<uint8_t>(out);
for (size_t i = 0; i < uint8_vec.size(); i++) {
phi::DenseTensor out_ = out->Slice(i, i + 1);
        // zero-pad the row to the common width max_size
        uint8_vec[i].resize(max_size, static_cast<uint8_t>(0));
framework::TensorFromVector(uint8_vec[i], ctx.device_context(), &out_);
}
} else if (res_type == "float") {
ctx.device_context().Alloc<float>(out);
for (size_t i = 0; i < float_vec.size(); i++) {
phi::DenseTensor out_ = out->Slice(i, i + 1);
framework::TensorFromVector(float_vec[i], ctx.device_context(), &out_);
}
}
auto* succeed = ctx.Output<phi::DenseTensor>("succeed");
ctx.device_context().Alloc<bool>(succeed);
std::vector<bool> succeed_wrapper{true};
framework::TensorFromVector(succeed_wrapper, ctx.device_context(), succeed);
}
};
} // namespace operators
} // namespace paddle
This diff is collapsed.
......@@ -223,6 +223,24 @@ cc_library(
phi_device_context
generator)
if(WITH_DISTRIBUTE)
set(BRPC_DEPS
brpc
ssl
crypto
protobuf
zlib
leveldb
snappy
gflags
glog)
cc_library(
rpc_utils
SRCS rpc_utils.cc
DEPS enforce ${BRPC_DEPS})
endif()
cc_library(
collective_helper
SRCS collective_helper.cc gen_comm_id_helper.cc
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/rpc_utils.h"
#include <algorithm>
#include <fstream>
#include <regex>
#include <sstream>
#include <unordered_set>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace platform {
// globals
static std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
// utils
static inline bool StartsWith(const std::string& str,
const std::string& prefix) {
return str.substr(0, prefix.length()) == prefix;
}
static inline bool EndsWith(const std::string& str, const std::string& suffix) {
return str.length() >= suffix.length() &&
str.substr(str.length() - suffix.length()) == suffix;
}
static inline std::string Replace(const std::string& str,
const std::string& old_str,
const std::string& new_str) {
if (old_str == new_str) {
return str;
}
std::stringstream ss;
size_t start_pos = 0;
size_t pos = str.find(old_str, start_pos);
while (pos != std::string::npos) {
ss << str.substr(start_pos, pos - start_pos) << new_str;
start_pos = pos + old_str.size();
pos = str.find(old_str, start_pos);
}
ss << str.substr(start_pos);
return ss.str();
}
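// e.g. Replace("foo@@bar@@", "@@", "") => "foobar"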
static inline bool IsChineseChar(wchar_t c) {
return (c >= 0x4E00 && c <= 0x9FFF) || (c >= 0x3400 && c <= 0x4DBF) ||
(c >= 0x20000 && c <= 0x2A6DF) || (c >= 0x2A700 && c <= 0x2B73F) ||
(c >= 0x2B740 && c <= 0x2B81F) || (c >= 0x2B820 && c <= 0x2CEAF) ||
(c >= 0xF900 && c <= 0xFAFF) || (c >= 0x2F800 && c <= 0x2FA1F);
}
static inline bool IsChinesePunct(wchar_t c) {
  std::unordered_set<wchar_t> puncts = {
      L'！', L'？', L'｡',  L'。', L'＂', L'＃', L'＄', L'％', L'＆', L'＇',
      L'（', L'）', L'＊', L'＋', L'，', L'－', L'／', L'：', L'；', L'＜',
      L'＝', L'＞', L'＠', L'［', L'＼', L'］', L'＾', L'＿', L'｀', L'｛',
      L'｜', L'｝', L'～', L'｟', L'｠', L'｢',  L'｣',  L'､',  L'、', L'〃',
      L'》', L'「', L'」', L'『', L'』', L'【', L'】', L'〔', L'〕', L'〖',
      L'〗', L'〘', L'〙', L'〚', L'〛', L'〜', L'〝', L'〞', L'〟', L'〰',
      L'〾', L'〿', L'–',  L'—', L'“',  L'”',  L'‘',  L'’'};
return puncts.count(c);
}
static inline int GetCharBytes(uint8_t byte) {
if ((byte & 0x80) == 0) {
return 1;
} else if ((byte & 0xE0) == 0xC0) {
return 2;
} else if ((byte & 0xF0) == 0xE0) {
return 3;
} else if ((byte & 0xF8) == 0xF0) {
return 4;
} else {
return -1;
}
}
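// Lead-byte patterns follow UTF-8: 0xxxxxxx => 1 byte, 110xxxxx => 2,
// 1110xxxx => 3, 11110xxx => 4; anything else (e.g. a continuation byte
// 10xxxxxx) is not a lead byte and yields -1.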
static inline bool IsValidContinuationByte(uint8_t byte) {
// check if the byte starts with 10
return (byte & 0xC0) == 0x80;
}
static inline uint8_t GetByteFromHex(const std::string& token) {
auto num_str = paddle::string::split_string(token, "_")[1];
num_str = num_str.substr(0, num_str.size() - 1);
return static_cast<uint8_t>(std::stoi(num_str, nullptr, 16));
}
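// Assuming byte-fallback tokens look like "[BFB_E4]" (the form implied by
// RecoverBFBTokens below): splitting on '_' gives "E4]", dropping the
// trailing ']' gives "E4", which parses as hex to 0xE4.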
// RpcTokenizer
void RpcTokenizer::Init(const std::string& path) {
if (path_ == path) {
return;
}
std::ifstream vocab_file(path);
std::string word;
int id;
while (vocab_file >> word >> id) {
ids_to_words_.emplace(id, word);
words_to_ids_.emplace(converter.from_bytes(word), id);
}
// update members
path_ = path;
}
void RpcTokenizer::Init(
const std::string& path,
const std::unordered_map<std::string, std::string>& special_set) {
if (path_ == path) {
return;
}
Init(path);
SetSpecialSet(special_set);
}
std::string RpcTokenizer::GetRecoveredToken(const std::vector<uint8_t>& bytes) {
std::string res;
int n = bytes.size();
int i = 0;
while (i < n) {
    int sz = 0;
    // skip invalid lead bytes, but never read past the end of the buffer
    while (i < n && (sz = GetCharBytes(bytes[i])) == -1) {
      ++i;
    }
    if (i >= n) {
      break;
    }
    if (i + sz <= n) {  // <=: a char may end exactly at the buffer boundary
      std::vector<uint8_t> valid_bytes;
      valid_bytes.emplace_back(bytes[i]);
      for (int j = 1; j < sz; ++j) {
        ++i;  // advance before checking: bytes[i] was still the lead byte
        if (!IsValidContinuationByte(bytes[i])) {
          break;
        }
        valid_bytes.emplace_back(bytes[i]);
      }
if (valid_bytes.size() == static_cast<size_t>(sz)) {
res += std::string(valid_bytes.begin(), valid_bytes.end());
}
}
++i;
}
return res;
}
std::vector<std::string> RpcTokenizer::RecoverBFBTokens(
const std::vector<std::string>& tokens) {
std::vector<std::string> new_tokens;
std::vector<uint8_t> tmp_bytes;
for (const auto& token : tokens) {
if (StartsWith(token, "[BFB")) {
tmp_bytes.emplace_back(GetByteFromHex(token));
} else {
if (!tmp_bytes.empty()) {
// since there may be illegal bytes, we need this function
// if all bytes are legal, we can simply use string constructor
const std::string recovered_token = GetRecoveredToken(tmp_bytes);
if (!recovered_token.empty()) {
new_tokens.emplace_back(recovered_token);
}
}
if (token != "[UNK]") {
new_tokens.emplace_back(token);
}
tmp_bytes.clear();
}
}
if (!tmp_bytes.empty()) {
const std::string recovered_token = GetRecoveredToken(tmp_bytes);
if (!recovered_token.empty()) {
new_tokens.emplace_back(recovered_token);
}
}
return new_tokens;
}
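// Sketch (relies on the boundary fix in GetRecoveredToken above): the byte
// tokens {"[BFB_E4]", "[BFB_BD]", "[BFB_A0]"} carry the UTF-8 bytes of
// "你", so {"[BFB_E4]", "[BFB_BD]", "[BFB_A0]", "ok"} => {"你", "ok"}.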
std::vector<std::string> RpcTokenizer::PostProcess(
const std::vector<std::string>& tokens,
const WordToIdMap& vocab,
bool aggressive_break,
const std::string& stop_token) {
std::unordered_set<std::string> break_words;
if (aggressive_break) {
break_words = {"[END]", "[gEND]", "[<S>]", "[UNK]", "[CLS]"};
} else {
break_words = {"[END]", "[gEND]"};
}
static const std::unordered_map<std::string, std::string> replace_words{
{"[<S>]", " "},
{"[<N>]", "\n"},
{"[<T>]", "\t"},
{"[<t>]", " "},
};
std::vector<std::string> new_text;
auto words = RecoverBFBTokens(tokens);
for (auto& word : words) {
if (break_words.count(word) || word == stop_token) {
break;
}
if (word.empty() || word == "[PAD]") {
continue;
}
if (replace_words.count(word)) {
new_text.emplace_back(replace_words.at(word));
continue;
}
auto unicode_word = converter.from_bytes(word);
bool is_chinese_char = IsChineseChar(unicode_word[0]);
bool is_chinese_punct = IsChinesePunct(unicode_word[0]);
if (is_chinese_char || is_chinese_punct || vocab.count(unicode_word) == 0) {
if (!new_text.empty() && EndsWith(new_text.back(), "@@")) {
auto& last_word = new_text.back();
last_word = Replace(last_word, "@@", "");
}
new_text.emplace_back(word);
} else if (!StartsWith(word, "##")) {
if (!new_text.empty() && EndsWith(new_text.back(), "@@")) {
auto& last_word = new_text.back();
last_word = Replace(last_word, "@@", "");
new_text.emplace_back(word);
} else if (!new_text.empty() && EndsWith(new_text.back(), "\n")) {
new_text.emplace_back(word);
} else {
if (!new_text.empty() && !new_text.back().empty() &&
IsChineseChar(converter.from_bytes(new_text.back())[0])) {
new_text.emplace_back(word);
} else {
if (!new_text.empty()) {
new_text.emplace_back(" ");
}
new_text.emplace_back(word);
}
}
} else {
if (!new_text.empty() && EndsWith(new_text.back(), "@@")) {
auto& last_word = new_text.back();
last_word = last_word.substr(0, last_word.size() - 2);
}
new_text.emplace_back(Replace(word, "##", ""));
}
}
if (!new_text.empty()) {
auto& last_word = new_text.back();
last_word = Replace(last_word, "@@", "");
}
return new_text;
}
int RpcCommContext::RpcSend(
const std::string& url,
const std::string& query,
void (*request_handler)(brpc::Controller*, int, const std::string&),
void (*response_handler)(brpc::Controller*,
int,
std::shared_ptr<bthread::CountdownEvent>),
brpc::HttpMethod http_method,
int timeout_ms,
int max_retry) {
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = "http";
options.timeout_ms = timeout_ms;
options.max_retry = max_retry;
PADDLE_ENFORCE_EQ(
channel.Init(url.c_str(), /*load_balancer*/ "", &options),
0,
phi::errors::Unavailable("Rpc send failed: init brpc channel error."));
auto& rpc_store = RpcRequestStore::Instance();
int request_id = rpc_store.GetRequestId();
auto event = std::make_shared<bthread::CountdownEvent>();
RpcRequestStore::Instance().InsertEvent(request_id, event);
  // the call is async, so the controller must live on the heap; the
  // unique_ptr guard in the response handler deletes it
auto* ctrl = new brpc::Controller();
ctrl->http_request().uri() = url.c_str();
ctrl->http_request().set_method(http_method);
ctrl->http_request().SetHeader("Content-Type", "application/json");
request_handler(ctrl, request_id, query);
channel.CallMethod(
nullptr,
ctrl,
nullptr,
nullptr,
brpc::NewCallback(response_handler, ctrl, request_id, event));
return request_id;
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <brpc/channel.h>
#include <bthread/countdown_event.h>
#include <atomic>
#include <codecvt>
#include <locale>
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/utils/string/string_helper.h"
namespace paddle {
namespace platform {
using WordToIdMap = std::unordered_map<std::wstring, int64_t>;
using IdToWordMap = std::unordered_map<int64_t, std::string>;
class RpcTokenizer {
public:
static RpcTokenizer& Instance() {
static RpcTokenizer instance;
return instance;
}
void Init(const std::string& path);
void Init(const std::string& path,
const std::unordered_map<std::string, std::string>& special_set);
void SetSpecialSet(
const std::unordered_map<std::string, std::string>& special_set) {
special_set_ = special_set;
}
bool Contains(int64_t id) { return ids_to_words_.count(id) > 0; }
  // NOTE: an exception will be raised if the id does not exist
std::string GetWordFromId(int64_t id) {
auto q = ids_to_words_.at(id);
if (special_set_.count(q) == 1) {
return special_set_.at(q);
} else {
return q;
}
}
template <typename T = int64_t>
std::string GetWordsFromIds(const std::vector<T>& ids,
bool aggressive_break = false,
const std::string& stop_token = "[gEND]") {
std::vector<std::string> tokens;
for (auto id : ids) {
if (!Contains(id)) {
continue;
}
tokens.emplace_back(GetWordFromId(id));
}
return paddle::string::join_strings(
PostProcess(tokens, words_to_ids_, aggressive_break, stop_token), "");
}
  // NOTE: an exception will be raised if the word does not exist
int64_t GetIdFromWord(const std::wstring& word) {
return words_to_ids_.at(word);
}
private:
std::string GetRecoveredToken(const std::vector<uint8_t>& bytes);
std::vector<std::string> RecoverBFBTokens(
const std::vector<std::string>& tokens);
std::vector<std::string> PostProcess(
const std::vector<std::string>& tokens,
const WordToIdMap& vocab,
bool aggressive_break = false,
const std::string& stop_token = "[gEND]");
private:
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter_;
std::string path_;
IdToWordMap ids_to_words_;
WordToIdMap words_to_ids_;
std::unordered_map<std::string, std::string> special_set_;
};
class RpcRequestStore {
public:
static RpcRequestStore& Instance() {
static RpcRequestStore instance;
return instance;
}
  int GetRequestId() {
    // fetch_add so concurrent senders never receive the same id; wrap near
    // INT32_MAX (id 0 then triggers the warnings in the Insert* methods)
    int id = request_id_.fetch_add(1) + 1;
    if (id == INT32_MAX) {
      request_id_ = -1;
    }
    return id;
  }
std::shared_ptr<bthread::CountdownEvent> GetEvent(int request_id) {
return id_to_event_map_[request_id];
}
int GetErrorCode(int request_id) { return id_to_err_map_[request_id]; }
std::string GetResponse(int request_id) {
return id_to_resp_map_[request_id];
}
void InsertEvent(int request_id,
const std::shared_ptr<bthread::CountdownEvent>& event) {
if (request_id == 0) {
LOG(WARNING) << "Total num of requests have exceeded int limits.";
}
id_to_event_map_.emplace(request_id, event);
}
void InsertErrorCode(int request_id, int error_code) {
if (request_id == 0) {
LOG(WARNING) << "Total num of requests have exceeded int limits.";
}
id_to_err_map_.emplace(request_id, error_code);
}
void InsertResponse(int request_id, const std::string& resp) {
if (request_id == 0) {
LOG(WARNING) << "Total num of requests have exceeded int limits.";
}
id_to_resp_map_.emplace(request_id, resp);
}
private:
  std::atomic<int> request_id_{0};  // value-initialize: a bare std::atomic<int> is uninitialized
std::unordered_map<int, std::shared_ptr<bthread::CountdownEvent>>
id_to_event_map_;
std::unordered_map<int, int> id_to_err_map_;
std::unordered_map<int, std::string> id_to_resp_map_;
};
struct RpcCommContext {
static int RpcSend(
const std::string& url,
const std::string& query,
void (*request_handler)(brpc::Controller*, int, const std::string&),
void (*response_handler)(brpc::Controller*,
int,
std::shared_ptr<bthread::CountdownEvent>),
brpc::HttpMethod http_method = brpc::HttpMethod::HTTP_METHOD_POST,
int timeout_ms = 10000,
int max_retry = 3);
};
} // namespace platform
} // namespace paddle
......@@ -379,5 +379,16 @@ if((WITH_ROCM OR WITH_GPU) AND (LINUX))
"PADDLE_DIST_UT_PORT=21532;http_proxy=;https_proxy=")
set_tests_properties(test_world_size_and_rank PROPERTIES TIMEOUT "120")
endif()
if((WITH_ROCM OR WITH_GPU) AND (LINUX))
bash_test_modules(
test_rpc_call_result
START_BASH
test_rpc_call_result.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21672;http_proxy=;https_proxy=")
set_tests_properties(test_rpc_call_result PROPERTIES TIMEOUT "120")
endif()
add_subdirectory(fleet)
add_subdirectory(multinode)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flask import Flask, request, jsonify
import argparse
app = Flask(__name__)
test_value = 0.66943359375
@app.route('/run/predict', methods=['POST'])
def echo():
# Get the data from the request
request_json = request.json
# data = request_json['text']
# Echo the data back in the response
response = {'result': [str(test_value)]}
# Return the response in JSON format
return jsonify(response)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, required=True, help='port')
parser.add_argument(
'--ip', type=str, required=False, default='localhost', help='ip'
)
args = parser.parse_args()
app.run(host=args.ip, port=args.port)
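# Quick manual check of the echo server above (the port and the `requests`
# dependency are assumptions, not part of the commit):
#   python py_server_test.py --port 8918
# then, from another script:
#   import requests
#   print(requests.post("http://localhost:8918/run/predict", json={}).json())
#   # -> {'result': ['0.66943359375']}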
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import subprocess
import unittest
import os
def rpc_test(use_ids, out_type, url):
paddle.enable_static()
MAX_SIZE_QUERY = 18
RES_TYPE = out_type
with open("vocab.txt", "w") as voc:
voc.write("ABC 0\n")
voc.write("EFG 1\n")
voc.write("HIG 2\n")
voc.write("[<S>] 3\n")
voc.write("[<N>] 4\n")
voc.write("[<t>] 5\n")
voc.write("[<T>] 6\n")
voc.write("##good 7\n")
voc.write("bad@@ 8\n")
voc.write("@@badok 9\n")
voc.write("你好 10\n")
voc.write("haha 11\n")
voc.write("##haha@@ 12\n")
voc.write("[PAD] 13\n")
voc.write("[gEnd] 14\n")
# network
in_query = fluid.data(name='X', shape=[-1, MAX_SIZE_QUERY], dtype='int32')
req_ids = paddle.static.nn.rpc_call(
in_query,
url,
"vocab.txt",
use_ids,
)
out_data, out_succeed = paddle.static.nn.rpc_result(req_ids, RES_TYPE)
paddle.static.Print(in_query)
paddle.static.Print(req_ids)
paddle.static.Print(out_data.astype("float32"))
query_tensor = np.array(
[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 1, 2],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 14],
]
).astype("int32")
# run
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
for _ in range(1):
        succeed, data = exe.run(
fluid.default_main_program(),
feed={
'X': query_tensor,
},
fetch_list=[out_succeed, out_data],
)
if out_type == "str":
print(data[0].tobytes().decode("utf-8", "ignore"))
else:
print(data[0])
class RPCCallTest(unittest.TestCase):
def test_cases(self):
ip = 'localhost'
port = int(os.environ.get("PADDLE_DIST_UT_PORT"))
server_cmd = f"python py_server_test.py --ip {ip} --port {port}"
with open(f"server.{port}.log", "w") as output:
process = subprocess.Popen(
server_cmd.split(), stdout=output, stderr=output
)
        try:
            for uid in [True, False]:
                for otype in ['str', 'float']:
                    rpc_test(uid, otype, f"http://{ip}:{port}/run/predict")
        except Exception as e:
            raise RuntimeError("rpc test error") from e
        finally:
            # always stop the echo server, even when every case passes
            process.kill()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python test_rpc_call_result.py
......@@ -45,3 +45,4 @@ test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_
test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
test_rpc_call_result,linux,rocm;gpu,120,DIST,test_rpc_call_result.sh,1,,http_proxy=;https_proxy=,
......@@ -59,7 +59,9 @@ from ...fluid.layers.sequence_lod import sequence_scatter # noqa: F401
from ...fluid.layers.sequence_lod import sequence_enumerate # noqa: F401
from ...fluid.layers.sequence_lod import sequence_reverse # noqa: F401
__all__ = [ #noqa
from .rpc_utils import rpc_call, rpc_result
__all__ = [ # noqa
'fc',
'batch_norm',
'embedding',
......@@ -101,4 +103,6 @@ __all__ = [ #noqa
'sequence_enumerate',
'sequence_reverse',
'StaticRNN',
'rpc_call',
'rpc_result',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
import paddle
class IDGen:
def __init__(self) -> None:
self.ids = {}
def gen_name_with_idx(self, name):
if name not in self.ids:
self.ids[name] = -1
self.ids[name] += 1
return name + "_" + str(self.ids[name])
def __call__(self, name) -> str:
return self.gen_name_with_idx(name)
id_gen = IDGen()
def rpc_call(src_ids=None, url="", voc_path="", cvt2str=True):
request_id = (
fluid.default_main_program()
.block(0)
.create_var(
name=id_gen("rpc_request_id"),
dtype="int32",
shape=[src_ids.shape[0]],
persistable=False,
stop_gradient=True,
)
)
src_ids = src_ids.astype("int32")
fluid.default_main_program().block(0).append_op(
type="rpc_call",
inputs={
'X': [src_ids],
},
outputs={"Out": [request_id]},
attrs={
"url": url,
"vocab_path": voc_path,
"use_ids": not cvt2str,
"timeout": 3000,
"retry": 100,
},
)
return request_id
def rpc_result(request_ids, result_dtype):
if result_dtype == "float":
res = (
fluid.default_main_program()
.block(0)
.create_var(
name=id_gen("rpc_res"),
dtype="float32",
shape=[request_ids.shape[0]],
persistable=False,
stop_gradient=True,
)
)
elif result_dtype == "str":
res = (
fluid.default_main_program()
.block(0)
.create_var(
name=id_gen("rpc_res"),
dtype="uint8",
shape=[request_ids.shape[0]],
persistable=False,
stop_gradient=True,
)
)
else:
raise ValueError("result dtype must be one of str ot float")
success = (
fluid.default_main_program()
.block(0)
.create_var(
name=id_gen("rpc_success"),
dtype="bool",
shape=[1],
persistable=False,
stop_gradient=True,
)
)
fluid.default_main_program().block(0).append_op(
type="rpc_result",
inputs={"X": [request_ids]},
outputs={"Out": [res], "succeed": [success]},
attrs={"res_type": result_dtype},
)
return res, success
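For orientation, here is a minimal static-graph usage sketch of the two helpers above. It is not part of the commit; it mirrors the unit test earlier in this diff, and the URL, port, and vocab path are placeholders.

import numpy as np
import paddle
from paddle import fluid

paddle.enable_static()
ids = fluid.data(name='X', shape=[-1, 4], dtype='int32')  # one query per row
req_ids = paddle.static.nn.rpc_call(
    ids,
    url="http://localhost:8918/run/predict",
    voc_path="vocab.txt",
    cvt2str=False,  # send the ids as-is instead of converting them to words
)
out, succeed = paddle.static.nn.rpc_result(req_ids, "float")

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
ok, data = exe.run(
    fluid.default_main_program(),
    feed={'X': np.zeros([2, 4], dtype='int32')},
    fetch_list=[succeed, out],
)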
......@@ -175,8 +175,8 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
out = paddle.sum((x - u) ** 2, axis, keepdim=keepdim, name=name)
dtype = x.dtype
n = paddle.cast(paddle.numel(x), paddle.int64) / paddle.cast(
paddle.numel(out), paddle.int64
n = paddle.cast(paddle.numel(x), dtype) / paddle.cast(
paddle.numel(out), dtype
)
n = n.astype(dtype)
if unbiased:
......
......@@ -21,7 +21,8 @@ else
fi
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $files; do
if [[ $file =~ ^(patches/.*) ]]; then
echo $file
if [[ $file =~ ^(patches/.*) || $file =~ ^(paddle/fluid/operators/collective/thirdparty/json.h) ]]; then
continue;
else
cpplint --filter=-readability/fn_size,-build/include_what_you_use,-build/c++11,-whitespace/parens $file;
......