提交 17009d06 编写于 作者: T typhoonzero

workable version

上级 a529d790
...@@ -205,6 +205,7 @@ if(WITH_DISTRIBUTE) ...@@ -205,6 +205,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
else() else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
endif() endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/string/printf.h"
USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail;
namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
void StartServer() {
f::Scope scope;
p::CPUPlace place;
scope.Var("NCCLID");
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
rpc_service->SetScope(&scope);
rpc_service->SetDevCtx(&dev_ctx);
rpc_service->SetProgram(&empty_program);
rpc_service->SetExecutor(&executor);
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
rpc_service->SetCond(0);
auto recv = rpc_service->Get();
LOG(INFO) << "got nccl id and stop server...";
rpc_service->ShutDown();
server_thread.join();
}
TEST(SendNcclId, Normal) {
std::thread server_thread(StartServer);
// wait server to start
sleep(2);
f::Scope scope;
p::CPUPlace place;
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var("NCCLID");
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id);
int port = rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client;
client.AsyncSendVariable(ep, dev_ctx, scope, "NCCLID");
client.Wait();
server_thread.join();
auto* ptr = rpc_service.release();
delete ptr;
}
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <stdio.h>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
...@@ -100,13 +101,13 @@ struct NCCLContextMap { ...@@ -100,13 +101,13 @@ struct NCCLContextMap {
} }
} else { } else {
PADDLE_ENFORCE_GT(node_count, 0); PADDLE_ENFORCE_GT(node_count, 0);
PADDLE_ENFORCE_EQ(node_count % places.size(), 0, // TODO(wuyi): need to ensure each node have same number of GPUs
"must have same number of GPUs on each node");
{ {
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
int nranks = node_count * order_.size(); int nranks = node_count * order_.size();
NCCLGroupGuard gurad;
for (auto &gpu_id : order_) { for (auto &gpu_id : order_) {
int rank = trainer_id * order_.size() + gpu_id; int rank = trainer_id * order_.size() + gpu_id;
VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
PADDLE_ENFORCE(cudaSetDevice(gpu_id)); PADDLE_ENFORCE(cudaSetDevice(gpu_id));
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
comms.get() + gpu_id, nranks, *nccl_id, rank)); comms.get() + gpu_id, nranks, *nccl_id, rank));
......
...@@ -53,6 +53,11 @@ class ParallelExecutor(object): ...@@ -53,6 +53,11 @@ class ParallelExecutor(object):
gradients of each device and scaled gradients would be gradients of each device and scaled gradients would be
aggregated. Otherwise, a customized scale value should be fed aggregated. Otherwise, a customized scale value should be fed
to the network. to the network.
num_nodes(int, default 0): If greater than 0, NCCL will be
initialized with multpile rank of nodes, each node should have
same number of GPUs. Distributed training will be enabled then.
trainer_id(int, default 0): Must use together with num_nodes.
trainer_id is the "rank" of current node starts from 0.
Returns: Returns:
A ParallelExecutor object. A ParallelExecutor object.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册