未验证 提交 7287630e 编写于 作者: Q QI JUN 提交者: GitHub

Repair nccl op test (#8575)

* fix nccl op unit test

* fix build error

* format code

* refine nccl related unit test

* fix build error

* add setGPUData

* clean up

* follow comments

* rm test_nccl.cu

* follow comment

* rm wait
上级 ada82a3e
......@@ -244,11 +244,11 @@ function(cc_test TARGET_NAME)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS})
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
if("${cc_test_DEPS}" MATCHES "ARCHIVE_START")
list(REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END)
endif()
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
......@@ -311,8 +311,8 @@ function(nv_test TARGET_NAME)
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_test(${TARGET_NAME} ${TARGET_NAME})
endif()
endfunction(nv_test)
......
......@@ -222,8 +222,6 @@ cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
if(WITH_GPU)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
......
......@@ -14,19 +14,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -41,26 +37,35 @@ USE_CUDA_ONLY_OP(ncclBcast);
namespace f = paddle::framework;
namespace p = paddle::platform;
static std::vector<int> gpu_list;
// test data amount
const f::DDim kDims = {100, 100};
const f::DDim kDims = {20, 20};
// nccl op common tester, init communicator.
class NCCLTester : public ::testing::Test {
public:
virtual void SetUp() override {
int count = p::GetCUDADeviceCount();
if (count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< count;
exit(0);
}
for (int i = 0; i < count; ++i) {
gpu_list_.emplace_back(i);
}
paddle::platform::CPUPlace cpu_place;
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
p::CUDAPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
dev_ctxs_.emplace_back(new p::CUDADeviceContext(place));
}
NCCLInitOp();
}
virtual void TearDown() override {
for (auto &device_context : dev_ctxs) {
for (auto &device_context : dev_ctxs_) {
delete device_context;
}
}
......@@ -70,36 +75,40 @@ class NCCLTester : public ::testing::Test {
std::unique_ptr<f::OpDesc> op1(new f::OpDesc);
op1->SetType("ncclInit");
op1->SetInput("parallel_scopes", {"p_scopes"});
op1->SetOutput("Communicator", {"comm"});
op1->SetAttr("gpus", {gpu_list});
auto *var = g_scope.Var("comm");
auto *var = g_scope_.Var("comm");
var->GetMutable<p::Communicator>();
auto *scope_var = g_scope_.Var("p_scopes");
auto *p_scopes = scope_var->GetMutable<std::vector<f::Scope *>>();
(*p_scopes).resize(gpu_list_.size());
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, cpu_place);
op->Run(g_scope_, cpu_place);
VLOG(1) << "NCCLInitOp finished.";
}
// Returns the value used to fill every element of the send tensor for
// `gpu_id`: the device id shifted by a fixed offset of 42.
int GetGPUData(int gpu_id) {
  const int kOffset = 42;
  return kOffset + gpu_id;
}
template <class T>
void PerThreadProgram(int gpu_id, const f::OpDesc &op_desc, f::Scope *scope) {
std::unique_lock<std::mutex> lk(mu);
std::unique_lock<std::mutex> lk(mu_);
const f::OpDesc *op1 = &op_desc;
p::CUDAPlace place(gpu_id);
auto &ctx = dev_ctxs.at(gpu_id);
auto &ctx = dev_ctxs_.at(gpu_id);
auto *send_tensor = scope->Var("st")->GetMutable<f::LoDTensor>();
auto *recv_tensor = scope->Var("rt")->GetMutable<f::LoDTensor>();
if (!send_tensor->numel()) {
send_tensor->Resize(kDims);
send_tensor->mutable_data<T>(kDims, place);
std::vector<T> send_vector(f::product(kDims), gpu_id);
std::vector<T> send_vector(f::product(kDims), GetGPUData(gpu_id));
paddle::framework::TensorFromVector<T>(send_vector, *ctx, send_tensor);
ctx->Wait();
VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel();
}
......@@ -118,30 +127,14 @@ class NCCLTester : public ::testing::Test {
}
public:
std::vector<p::DeviceContext *> dev_ctxs;
f::Scope g_scope;
std::mutex mu;
std::vector<p::DeviceContext *> dev_ctxs_;
f::Scope g_scope_;
std::mutex mu_;
std::vector<int> gpu_list_;
};
// ncclInitOp with desc
// Builds an "ncclInit" operator from a hand-written OpDesc and runs it on the
// CPU place, with the communicator variable "x1" pre-created in a local scope.
TEST(NCCL, ncclInitOp) {
  // Describe the operator: type, output slot and the list of GPUs to use.
  f::OpDesc init_desc;
  init_desc.SetType("ncclInit");
  init_desc.SetOutput("Communicator", {"x1"});
  init_desc.SetAttr("gpus", {gpu_list});

  f::Scope scope;
  p::CPUPlace cpu;

  // The output variable must exist and hold a Communicator before the op runs.
  scope.Var("x1")->GetMutable<p::Communicator>();

  auto init_op = f::OpRegistry::CreateOp(init_desc);
  VLOG(1) << "invoke NCCLInitOp.";
  init_op->Run(scope, cpu);
  VLOG(1) << "NCCLInitOp finished.";
}
// Intentionally empty: the NCCLTester fixture's SetUp() already calls
// NCCLInitOp(), so instantiating the fixture is the whole test.
TEST_F(NCCLTester, ncclInitOp) {}
// ncclAllReduceOp with desc
TEST_F(NCCLTester, ncclAllReduceOp) {
......@@ -155,23 +148,25 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
for (size_t i = 0; i < gpu_list_.size(); ++i) {
dev_scopes.emplace_back(&g_scope_.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list_[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
ths[i].join();
}
// check results
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
float expected_result = 0.0;
for (int gpu_id : gpu_list_) {
expected_result = expected_result + GetGPUData(gpu_id);
}
for (size_t i = 0; i < dev_scopes.size(); ++i) {
p::CPUPlace cpu_place;
p::CUDAPlace gpu_place(gpu_list[i]);
p::CUDAPlace gpu_place(gpu_list_[i]);
auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -180,12 +175,12 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::CUDAPlace(gpu_list[i]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list_[i]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[i])->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[i])->stream());
for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
ASSERT_NEAR(ct[j], expected_result, 1e-5);
}
}
}
......@@ -204,22 +199,24 @@ TEST_F(NCCLTester, ncclReduceOp) {
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
for (size_t i = 0; i < gpu_list_.size(); ++i) {
dev_scopes.emplace_back(&g_scope_.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list_[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
ths[i].join();
}
// check results on
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
float expected_result = 0.0;
for (int gpu_id : gpu_list_) {
expected_result = expected_result + GetGPUData(gpu_id);
}
p::CPUPlace cpu_place;
p::CUDAPlace gpu_place(gpu_list[kRoot]);
p::CUDAPlace gpu_place(gpu_list_[kRoot]);
auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -229,12 +226,12 @@ TEST_F(NCCLTester, ncclReduceOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::CUDAPlace(gpu_list[kRoot]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list_[kRoot]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[kRoot])->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[kRoot])->stream());
for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
ASSERT_NEAR(ct[j], expected_result, 1e-5);
}
}
......@@ -252,23 +249,22 @@ TEST_F(NCCLTester, ncclBcastOp) {
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
for (size_t i = 0; i < gpu_list_.size(); ++i) {
dev_scopes.emplace_back(&g_scope_.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list_[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
ths[i].join();
}
const int idx = 1;
// check results on
float result = kRoot;
float result = GetGPUData(kRoot);
p::CPUPlace cpu_place;
p::CUDAPlace gpu_place(gpu_list[idx]);
p::CUDAPlace gpu_place(gpu_list_[idx]);
auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -277,42 +273,11 @@ TEST_F(NCCLTester, ncclBcastOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::CUDAPlace(gpu_list[idx]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list_[idx]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[idx])->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[idx])->stream());
for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
int main(int argc, char **argv) {
// FIXME(tonyyang-svail):
// Due to the driver issue on our CI, disable for now
return 0;
const int dev_count = p::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::CUDAPlace(i));
gpu_list.emplace_back(i);
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
// device context should be release before scope.
// otherwise driver will down.
return RUN_ALL_TESTS();
}
......@@ -48,7 +48,6 @@ nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context)
cc_library(device_tracer SRCS device_tracer.cc DEPS profiler_proto ${GPU_CTX_DEPS})
cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
static int dev_count = 0;
namespace paddle {
namespace platform {
// Smoke test: create `dev_count` NCCL communicators in one call, then destroy
// each of them again.
TEST(NCCL, init) {
  std::vector<ncclComm_t> communicators(dev_count);
  PADDLE_ENFORCE(
      dynload::ncclCommInitAll(communicators.data(), dev_count, nullptr));
  for (auto& communicator : communicators) {
    dynload::ncclCommDestroy(communicator);
  }
}
// Bundles what one GPU needs for a collective call in these tests: a send
// buffer, a receive buffer and a CUDADeviceContext for that GPU.
//
// NOTE(review): callers in this file call SetDeviceId(gpu_id) right before
// constructing an instance; presumably the thrust buffers are allocated on the
// current device -- confirm before constructing without that call.
template <class T>
struct PerThreadData {
  thrust::device_vector<T> send_buff;
  thrust::device_vector<T> recv_buff;
  CUDADeviceContext dev_ctx;

  // Creates the device context for `gpu_id`, sizes both buffers to `size`
  // elements and fills the send buffer with 0, 1, ..., size - 1.
  PerThreadData(int gpu_id, size_t size) : dev_ctx(CUDAPlace(gpu_id)) {
    send_buff.resize(size);
    size_t pos = 0;
    while (pos < size) {
      send_buff[pos] = static_cast<T>(pos);
      ++pos;
    }
    recv_buff.resize(size);
  }

  // Raw pointer to the start of the send buffer.
  T* SendBuff() { return thrust::raw_pointer_cast(send_buff.data()); }

  // Raw pointer to the start of the receive buffer.
  T* RecvBuff() { return thrust::raw_pointer_cast(recv_buff.data()); }
};
// Number of double elements each device contributes to the all-reduce.
static constexpr int ELEM_COUNT = 10000;

// Sum all-reduce across `dev_count` devices. Every device sends the sequence
// 0, 1, ..., ELEM_COUNT - 1, so element j of every receive buffer must equal
// j * dev_count afterwards.
TEST(NCCL, all_reduce) {
  std::vector<ncclComm_t> comms;
  comms.resize(dev_count);
  VLOG(1) << "Initializing ncclComm";
  // Check the NCCL status instead of discarding it: continuing with
  // uninitialized communicators would only fail later and less clearly.
  PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
  VLOG(1) << "ncclComm initialized";
  VLOG(1) << "Creating thread data";
  std::vector<std::unique_ptr<PerThreadData<double>>> data;
  data.reserve(dev_count);
  for (int i = 0; i < dev_count; ++i) {
    VLOG(1) << "Creating thread data for device " << i;
    // Select the device before its buffers and context are created.
    SetDeviceId(i);
    data.emplace_back(new PerThreadData<double>(i, ELEM_COUNT));
  }
  VLOG(1) << "Thread data created";

  // Sanity-check the input: copy each send buffer back to the host and verify
  // it holds 0, 1, 2, ...
  VLOG(1) << "Check send_buf data";
  for (int i = 0; i < dev_count; ++i) {
    VLOG(1) << "Check on device " << i;
    SetDeviceId(i);
    thrust::host_vector<double> tmp = data[i]->send_buff;
    for (size_t j = 0; j < tmp.size(); ++j) {
      ASSERT_NEAR(static_cast<double>(j), tmp[j], 1e-5);
    }
  }

  // Issue one all-reduce per device from this single thread, wrapped in an
  // NCCL group; the group calls' status codes are checked as well.
  VLOG(1) << "Invoking ncclAllReduce";
  PADDLE_ENFORCE(dynload::ncclGroupStart());
  for (int i = 0; i < dev_count; ++i) {
    VLOG(1) << "Invoking ncclAllReduce with device " << i;
    SetDeviceId(i);
    PADDLE_ENFORCE(dynload::ncclAllReduce(
        data[i]->SendBuff(), data[i]->RecvBuff(), ELEM_COUNT, ncclDouble,
        ncclSum, comms[i], data[i]->dev_ctx.stream()));
    VLOG(1) << "Invoked ncclAllReduce for device " << i;
  }
  PADDLE_ENFORCE(dynload::ncclGroupEnd());
  VLOG(1) << "Invoked ncclAllReduce";

  // Wait on every device context before reading the results back.
  VLOG(1) << "Sync devices";
  for (int i = 0; i < dev_count; ++i) {
    VLOG(1) << "Sync device " << i;
    SetDeviceId(i);
    data[i]->dev_ctx.Wait();
  }
  VLOG(1) << "device synced";

  // Every device must now hold the element-wise sum over all devices.
  for (int i = 0; i < dev_count; ++i) {
    SetDeviceId(i);
    VLOG(1) << "Checking vector on device " << i;
    thrust::host_vector<double> tmp = data[i]->recv_buff;
    for (size_t j = 0; j < tmp.size(); ++j) {
      auto elem = static_cast<double>(j);
      elem *= dev_count;
      ASSERT_NEAR(tmp[j], elem, 1e-4);
    }
  }

  for (int i = 0; i < dev_count; ++i) {
    dynload::ncclCommDestroy(comms[i]);
  }
}
} // namespace platform
} // namespace paddle
int main(int argc, char** argv) {
dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::CUDAPlace(i));
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册