Commit 1ef6cdb6 authored by Yancey1989

move dist codes from operators/detail to operators/distributed

Parent 4b7ae145
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_client.h"
#endif
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>()
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendComplete();
}
#endif
......
......@@ -184,8 +184,8 @@ else()
set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
......
......@@ -15,13 +15,13 @@
#pragma once
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
#define RPCSERVER_T distributed::AsyncGRPCServer
#define RPCCLIENT_T distributed::GRPCClient
#else
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc_server.h"
#define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T distributed::BRPCClient
#endif
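For context, a minimal usage sketch (not part of the commit) of how these transport-neutral aliases are consumed; the function name and `endpoint` parameter are illustrative, but the calls mirror the test code further down in this diff:

#include <memory>
#include <string>

#include "paddle/fluid/operators/detail/macros.h"

namespace distributed = paddle::operators::distributed;

// Illustrative only: the same code builds against gRPC or BRPC depending on
// whether PADDLE_WITH_GRPC is defined.
void UseRpcAliases(const std::string& endpoint) {
  // RPCSERVER_T expands to distributed::AsyncGRPCServer or AsyncBRPCServer.
  std::unique_ptr<distributed::RPCServer> server(new RPCSERVER_T(endpoint, 1));
  // RPCCLIENT_T expands to distributed::GRPCClient or distributed::BRPCClient.
  distributed::RPCClient* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
  (void)server;  // server->StartServer() would block serving requests.
  (void)client;  // client->Wait() blocks on outstanding async RPCs.
}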
......
if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
......
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server");
......@@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
return q;
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -31,13 +31,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
struct ChannelContext {
brpc::Channel channel;
......@@ -95,6 +95,6 @@ class BRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(BRPCClient);
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv {
typedef std::unordered_map<std::string,
paddle::operators::detail::RequestHandler*>
paddle::operators::distributed::RequestHandler*>
HandlerMap;
class BRPCServiceImpl : public SendRecvService {
......@@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService {
: request_send_h_(nullptr),
request_get_h_(nullptr),
request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) {
request_send_h_ = it->second;
}
it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
it = rpc_call_map.find(paddle::operators::distributed::kRequestGet);
if (it != rpc_call_map.end()) {
request_get_h_ = it->second;
}
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch);
it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
}
......@@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService {
}
private:
paddle::operators::detail::RequestHandler* request_send_h_;
paddle::operators::detail::RequestHandler* request_get_h_;
paddle::operators::detail::RequestHandler* request_prefetch_h_;
paddle::operators::distributed::RequestHandler* request_send_h_;
paddle::operators::distributed::RequestHandler* request_get_h_;
paddle::operators::distributed::RequestHandler* request_prefetch_h_;
};
} // namespace sendrecv
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
void AsyncBRPCServer::StartServer() {
// Instance of your service.
......@@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
}
}; // namespace detail
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -19,12 +19,12 @@ limitations under the License. */
#include <string>
#include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
class AsyncBRPCServer final : public RPCServer {
public:
......@@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer {
int ready_;
};
}; // namespace detail
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -17,11 +17,11 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
GrpcByteBufferSource::GrpcByteBufferSource() {}
......@@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_;
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -106,7 +106,7 @@ class GrpcBufferReader final
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
......@@ -183,6 +183,6 @@ class GrpcByteSource : public Source {
char space_[sizeof(Reader)];
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include <sys/time.h>
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
void GRPCClient::InitImpl() { InitEventLoop(); }
......@@ -276,6 +276,6 @@ std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
return ch;
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -38,13 +38,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
struct VarHandle {
std::string ep;
......@@ -226,6 +226,6 @@ class GRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(GRPCClient);
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -21,8 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
......@@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize
......@@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// deserialize zero-copy
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
// operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope;
scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx);
operators::distributed::VariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar();
......@@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize
......@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy
framework::Scope scope;
scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx);
operators::distributed::VariableResponse resp(&scope, &ctx);
if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0);
} else {
......
......@@ -15,13 +15,13 @@ limitations under the License. */
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter;
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
enum CallStatus { PROCESS = 0, FINISH };
// reference:
......@@ -74,7 +74,7 @@ class RequestSend final : public RequestBase {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(),
!request_handler->sync_mode()));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
......@@ -106,7 +106,7 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
......@@ -150,7 +150,8 @@ class RequestPrefetch final : public RequestBase {
local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
......@@ -354,6 +355,6 @@ void AsyncGRPCServer::HandleRequest(
}
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -29,17 +29,17 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/grpc_service.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
class RequestBase;
......@@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer {
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
};
}; // namespace detail
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -23,7 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -42,16 +42,17 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> {
class SerializationTraits<paddle::operators::distributed::VariableResponse> {
public:
static Status Serialize(
const paddle::operators::detail::VariableResponse& msg,
const paddle::operators::distributed::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status();
}
static Status Deserialize(grpc_byte_buffer* buffer,
paddle::operators::detail::VariableResponse* msg,
static Status Deserialize(
grpc_byte_buffer* buffer,
paddle::operators::distributed::VariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload");
......@@ -59,7 +60,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> {
Status result = g_core_codegen_interface->ok();
if (result.ok()) {
paddle::operators::detail::GrpcByteSource source(buffer);
paddle::operators::distributed::GrpcByteSource source(buffer);
int ret = msg->Parse(&source);
if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
......@@ -73,7 +74,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> {
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
enum class GrpcMethod {
kSendVariable,
......@@ -118,6 +119,6 @@ class GrpcService final {
};
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds
......@@ -144,6 +144,6 @@ class ProtoEncodeHelper {
char* limit_; // Just for CHECKs
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -31,7 +31,7 @@
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
......@@ -124,6 +124,6 @@ class RequestHandler {
RPCServer* rpc_server_;
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -20,12 +20,12 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
......@@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -28,11 +28,11 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
class RequestSendHandler final : public RequestHandler {
public:
......@@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override;
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -22,7 +22,7 @@
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
class RPCClient {
public:
......@@ -84,6 +84,6 @@ class RPCClient {
static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_;
};
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
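The two static members above (`init_flag_`, `rpc_client_`) back the `GetInstance<T>()` accessor used throughout this commit. A hypothetical reconstruction, since the real body is elided from this diff:

// Hypothetical sketch, not the actual implementation: std::call_once on
// init_flag_ (requires <mutex>) guarantees one process-wide client; T is
// GRPCClient or BRPCClient, i.e. whatever RPCCLIENT_T resolves to.
template <typename T>
RPCClient* RPCClient::GetInstance() {
  std::call_once(init_flag_, [] {
    rpc_client_.reset(new T());
    // Presumably transport setup is also triggered here, cf.
    // GRPCClient::InitImpl() starting the event loop earlier in this diff.
  });
  return rpc_client_.get();
}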
......@@ -17,11 +17,11 @@
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown ";
......@@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -19,11 +19,11 @@
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
class RPCServer {
public:
......@@ -86,6 +86,6 @@ class RPCServer {
friend class RequestHandler;
};
}; // namespace detail
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -22,18 +22,18 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;
namespace distributed = paddle::operators::distributed;
USE_OP(lookup_table);
std::unique_ptr<detail::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
......@@ -113,19 +113,21 @@ void StartServer() {
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get());
g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get()));
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true));
g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
......@@ -23,14 +23,14 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
using VarMsg = sendrecv::VariableMessage;
......@@ -222,11 +222,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
operators::detail::VariableResponse resp(scope, &ctx);
operators::distributed::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -25,12 +25,12 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
typedef void (*DestroyCallback)(void*);
......@@ -61,6 +61,6 @@ inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
}
}
} // namespace detail
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include <string>
#include <utility>
......@@ -22,12 +22,12 @@
#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
enum WireType {
WIRETYPE_VARINT = 0,
......@@ -158,13 +158,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
static_cast<size_t>(tensor->numel()),
PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type())));
paddle::operators::distributed::ToTypeIndex(
meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
......@@ -480,6 +480,6 @@ int VariableResponse::Parse(Source* source) {
return 0;
}
}; // namespace detail
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -22,17 +22,17 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle {
namespace operators {
namespace detail {
namespace distributed {
class VariableResponse {
public:
......@@ -99,6 +99,6 @@ class VariableResponse {
sendrecv::VariableMessage meta_;
};
}; // namespace detail
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -42,8 +42,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait();
......
......@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
......@@ -60,7 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
......@@ -80,11 +81,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Cannot use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a weird crash.
detail::RequestSendHandler rpc_h(true);
std::unique_ptr<detail::RPCServer> rpc_service(
distributed::RequestSendHandler rpc_h(true);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
......@@ -95,11 +96,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, rpc_service.get()));
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(detail::kRequestSend);
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(detail::kRequestSend);
rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
......
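Condensed, the server side of GenNCCLIdOp above follows a small barrier protocol; a summary sketch of the diff (scope/program/executor wiring and error handling omitted):

distributed::RequestSendHandler rpc_h(true);  // sync_mode
std::unique_ptr<distributed::RPCServer> rpc_service(
    new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());

std::thread server_thread(
    std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);      // start accepting Send
rpc_service->WaitBarrier(distributed::kRequestSend);  // block until the id arrives
rpc_service->ShutDown();
server_thread.join();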
......@@ -21,14 +21,14 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
void RunServer(std::shared_ptr<detail::RPCServer> service) {
void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
......@@ -121,12 +121,12 @@ void ListenAndServOp::RunSyncLoop(
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(detail::kRequestSend);
rpc_service_->WaitBarrier(detail::kRequestSend);
rpc_service_->SetCond(distributed::kRequestSend);
rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->SetCond(distributed::kRequestGet);
break;
}
......@@ -154,11 +154,11 @@ void ListenAndServOp::RunSyncLoop(
recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get())
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
} // while(true)
}
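The loop above is the synchronous-training handshake; per mini-batch it alternates two barriers. A condensed sketch (names as in this commit):

// Sketch of RunSyncLoop's per-iteration choreography, summarizing the diff:
while (true) {
  rpc_service_->SetCond(distributed::kRequestSend);      // 1. accept gradients
  rpc_service_->WaitBarrier(distributed::kRequestSend);  //    from all trainers
  if (rpc_service_->IsExit()) {
    rpc_service_->SetCond(distributed::kRequestGet);     // unblock pending Gets
    break;
  }
  // 2. run the optimize blocks on the merged gradients (elided).
  rpc_service_->SetCond(distributed::kRequestGet);       // 3. serve updated
  rpc_service_->WaitBarrier(distributed::kRequestGet);   //    parameters
  rpc_service_->ResetBarrierCounter();
  // 4. clear received sparse vars (ResetSparseVarRecorder) for the next batch.
}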
......@@ -215,13 +215,13 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
}
static void FillRequestCtx(
detail::RequestHandler *h, framework::Scope *scope,
distributed::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program,
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
detail::RPCServer *rpc_server) {
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
......@@ -249,14 +249,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode));
request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
new distributed::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
......
......@@ -24,8 +24,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle {
namespace operators {
......@@ -33,7 +33,7 @@ namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service);
void RunServer(std::shared_ptr<distributed::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase {
public:
......@@ -62,10 +62,11 @@ class ListenAndServOp : public framework::OperatorBase {
const platform::Place& dev_place) const override;
protected:
mutable std::shared_ptr<detail::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
};
......
......@@ -41,8 +41,8 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......
......@@ -43,8 +43,8 @@ class RecvOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
......
......@@ -44,8 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
......
......@@ -45,8 +45,8 @@ class SendOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......
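The communication ops above (fetch_barrier, prefetch, recv, send_barrier, send) share one client-side pattern; a minimal sketch:

// Common pattern, summarized from the op hunks above: fetch the process-wide
// client, enqueue async RPCs, then block until they complete.
distributed::RPCClient* rpc_client =
    distributed::RPCClient::GetInstance<RPCCLIENT_T>();
// ... per-variable async send/get calls (elided from these hunks) ...
rpc_client->Wait();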
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
......@@ -37,11 +37,11 @@ USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail;
namespace distributed = paddle::operators::distributed;
namespace string = paddle::string;
std::unique_ptr<detail::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
void StartServer() {
f::Scope scope;
......@@ -57,14 +57,14 @@ void StartServer() {
g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get());
g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get()));
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend);
g_rpc_service->WaitBarrier(detail::kRequestSend);
g_rpc_service->SetCond(distributed::kRequestSend);
g_rpc_service->WaitBarrier(distributed::kRequestSend);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
......@@ -72,7 +72,7 @@ void StartServer() {
}
TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new detail::RequestSendHandler(true));
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
......@@ -91,7 +91,8 @@ TEST(SendNcclId, RPCServer) {
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
......