Commit 2378aa8a authored by WangXi, committed by gongweibao

[Cherry-pick 1.6] Fix DGC test and DGC nan bug (#20708)

Parent: 8bc2cbd8
@@ -23,8 +23,8 @@ INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
 ExternalProject_Add(
     extern_dgc
     ${EXTERNAL_PROJECT_LOG_ARGS}
-    URL "http://fleet.bj.bcebos.com/collective.tgz"
-    URL_MD5 "015d565156c3de4e30fe25473f47e7a9"
+    URL "http://fleet.bj.bcebos.com/collective_ef2216a.tgz"
+    URL_MD5 "2f67549fd5f1262383d83289abc4f88f"
     SOURCE_DIR "${DGC_SOURCES_DIR}"
     CONFIGURE_COMMAND ""
     BUILD_COMMAND make -j
......
@@ -171,6 +171,10 @@ void AllReduceOpHandle::NCCLAllReduceFunc(
     }
   });
 
+  SyncNCCLAllReduce();
+}
+
+void AllReduceOpHandle::SyncNCCLAllReduce() {
   if (FLAGS_sync_nccl_allreduce) {
     for (auto &p : places_) {
       int dev_id = boost::get<platform::CUDAPlace>(p).device;
......
@@ -63,6 +63,8 @@ class AllReduceOpHandle : public OpHandleBase {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   void NCCLAllReduceFunc(
       const std::vector<std::function<void()>> &all_reduce_calls);
+
+  void SyncNCCLAllReduce();
 #endif
 
   void AllReduceImpl(const std::vector<VarHandle *> &in_var_handles,
......
@@ -20,6 +20,7 @@
 #include "paddle/fluid/framework/details/variable_visitor.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -105,7 +106,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
   size_t in_numel = 0;
   size_t out_numel = 0;
   PADDLE_ENFORCE(nranks_ > 1);
-  std::vector<std::function<void()>> all_reduce_calls;
+  std::vector<std::function<void()>> all_gather_calls;
+  std::vector<std::function<void()>> sparse_reduce_calls;
   std::vector<memory::AllocationPtr> allocations;
@@ -141,15 +143,45 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
             << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
             << ", k:" << k << ", place:" << place << ", dtype:" << dtype;
 
-    all_reduce_calls.emplace_back([=] {
-      PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
-          in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
-          stream));
+    all_gather_calls.emplace_back([=] {
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+          in_tensor_buf, gather_buff, 2 * k, static_cast<ncclDataType_t>(dtype),
+          comm, stream));
+    });
+
+    sparse_reduce_calls.emplace_back([=] {
+      platform::CUDADeviceGuard guard(dev_id);
+      PADDLE_ENFORCE_EQ(paddle::communication::dgc::sparseReduce(
+                            gather_buff, k, out_tensor_buf,
+                            static_cast<int>(out_numel), nranks_, stream),
+                        true);
     });
   }
 
   WaitInputVarGenerated();
-  NCCLAllReduceFunc(all_reduce_calls);
+  SparseAllReduceFunc(all_gather_calls, sparse_reduce_calls);
+}
+
+void SparseAllReduceOpHandle::SparseAllReduceFunc(
+    const std::vector<std::function<void()>> &all_gather_calls,
+    const std::vector<std::function<void()>> &sparse_reduce_calls) {
+  this->RunAndRecordEvent([&] {
+    if (all_gather_calls.size() == 1UL) {
+      // Do not use NCCLGroup when manage NCCL by per thread per device
+      all_gather_calls[0]();
+    } else {
+      platform::NCCLGroupGuard guard;
+      for (auto &call : all_gather_calls) {
+        call();
+      }
+    }
+
+    for (auto &call : sparse_reduce_calls) {
+      call();
+    }
+  });
+
+  SyncNCCLAllReduce();
 }
 
 int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
......
@@ -43,6 +43,9 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
   int GetKValue(const std::string &grad_name);
   bool IsEncoded();
   void RunImplEncoded();
+  void SparseAllReduceFunc(
+      const std::vector<std::function<void()>> &all_gather_calls,
+      const std::vector<std::function<void()>> &sparse_reduce_calls);
 
  private:
   bool is_encoded_{false};
......
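The sparse all-reduce change above replaces the fused `sparseAllGReduce` call with two explicit phases: an `ncclAllGather` that exchanges each rank's `2 * k` encoded numbers (k indices plus k values), and a per-device `sparseReduce` that scatter-adds the gathered pairs into the dense output while a `CUDADeviceGuard` pins the correct device. A minimal NumPy sketch of that two-phase semantics (illustrative only; names and shapes are assumptions, not Paddle code):

```python
# Hedged NumPy sketch (not Paddle code) of the DGC sparse all-reduce semantics:
# every rank contributes k (index, value) pairs, an all-gather collects
# nranks * k pairs, and a local sparse reduce scatter-adds them into the
# dense output buffer.
import numpy as np

def sparse_allreduce(per_rank_pairs, out_numel):
    """per_rank_pairs: one (indices, values) pair of arrays per rank."""
    # Phase 1: "all-gather" - afterwards every rank holds all 2*k*nranks numbers.
    gathered = [(idx, val) for idx, val in per_rank_pairs]
    # Phase 2: local sparse reduce into a dense buffer of out_numel elements.
    out = np.zeros(out_numel, dtype=np.float32)
    for idx, val in gathered:
        np.add.at(out, idx, val)  # duplicate indices accumulate
    return out

# Example: 2 ranks, k=2 top-k gradients each, dense length 6.
rank0 = (np.array([0, 3]), np.array([0.5, -1.0], dtype=np.float32))
rank1 = (np.array([3, 5]), np.array([2.0, 0.25], dtype=np.float32))
print(sparse_allreduce([rank0, rank1], out_numel=6))
# -> [0.5  0.  0.  1.  0.  0.25]
```

In the real handle, `SparseAllReduceFunc` runs the gather calls inside the NCCL group and the reduce calls after it, each under its own device guard.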
@@ -465,8 +465,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
         new details::SparseAllReduceOpHandle(
             result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
             scopes, places, multi_nccl_ctxs_, is_encoded,
-            static_cast<int>(strategy_.trainers_endpoints_.size()) *
-                places_.size()));
+            strategy_.num_trainers_ * places_.size()));
   } else {
     result->Get<GraphOps>(kGraphOps).emplace_back(
         new details::AllReduceOpHandle(
......
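The handle's global rank count is now derived from the number of trainer processes rather than the length of the endpoint list. A hedged, purely illustrative calculation of the value that gets passed:

```python
# Illustrative arithmetic only (values are made up): the nranks argument passed
# to SparseAllReduceOpHandle mirrors strategy_.num_trainers_ * places_.size().
num_trainers = 2   # e.g. two trainer processes
num_places = 4     # e.g. four GPUs driven by each process
nranks = num_trainers * num_places
assert nranks == 8  # eight ranks take part in the sparse all-reduce
```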
@@ -32,9 +32,6 @@ limitations under the License. */
 #include "paddle/fluid/platform/init.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/string/piece.h"
-#if defined(PADDLE_WITH_DGC)
-#include "dgc/dgc.h"
-#endif
 
 DECLARE_int32(paddle_num_threads);
 DEFINE_int32(multiple_of_cupti_buffer_size, 1,
@@ -51,10 +48,6 @@ namespace framework {
 std::once_flag gflags_init_flag;
 std::once_flag p2p_init_flag;
 
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-std::once_flag dgc_init_flag;
-#endif
-
 void InitGflags(std::vector<std::string> argv) {
   std::call_once(gflags_init_flag, [&]() {
     FLAGS_logtostderr = true;
@@ -229,15 +222,5 @@ void InitGLOG(const std::string &prog_name) {
 #endif
 }
 
-#if defined(PADDLE_WITH_DGC)
-void InitDGC() {
-  std::call_once(dgc_init_flag, []() {
-    PADDLE_ENFORCE(paddle::communication::dgc::dynloadNcclLib());
-  });
-}
-#else
-void InitDGC() {}
-#endif
-
 }  // namespace framework
 }  // namespace paddle
@@ -30,8 +30,6 @@ void InitDevices(bool init_p2p);
 void InitDevices(bool init_p2p, const std::vector<int> devices);
 
-void InitDGC();
-
 #ifndef _WIN32
 void SignalHandle(const char *data, int size);
 #endif
......
@@ -1352,7 +1352,6 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("init_gflags", framework::InitGflags);
   m.def("init_glog", framework::InitGLOG);
-  m.def("init_dgc", framework::InitDGC);
   m.def("load_op_library", framework::LoadOpLib);
   m.def("init_devices",
         [](bool init_p2p) { framework::InitDevices(init_p2p); });
......
@@ -271,7 +271,6 @@ class CollectiveOptimizer(DistributedOptimizer):
         node_num = self._node_num()
         assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num
 
-        self._strategy.fuse_all_reduce_ops = True
         exec_strategy = self._strategy.exec_strategy
 
         if node_num <= 1:
......
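CollectiveOptimizer no longer forces `fuse_all_reduce_ops` on, since fused all-reduce assumes dense gradients while DGC's encoded gradients behave like sparse ones. A hedged sketch of how a caller can still opt in explicitly when DGC is not in use (standard fluid BuildStrategy API assumed):

```python
import paddle.fluid as fluid

# Opt in to gradient fusion explicitly; leave this off when DGC is enabled,
# because fuse_all_reduce_ops requires non-sparse gradients.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = True
```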
@@ -959,8 +959,6 @@ class DGCMomentumOptimizer(MomentumOptimizer):
         super(DGCMomentumOptimizer, self).__init__(
             learning_rate, momentum, use_nesterov, regularization, name)
 
-        core.init_dgc()
-
     def _add_auto_increment_var(self, counter_name, begin, step=1):
         helper = LayerHelper('global_step_counter')
         counter, is_new_var = helper.create_or_get_global_variable(
......
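With the explicit `core.init_dgc()` binding removed, constructing the DGC optimizer is all that is needed on the Python side. A minimal usage sketch, assuming the fluid 1.6 `DGCMomentumOptimizer` signature (the values shown are illustrative):

```python
import paddle.fluid as fluid

# rampup_begin_step: the global step at which DGC sparsification starts;
# before that the optimizer behaves like plain momentum SGD.
optimizer = fluid.optimizer.DGCMomentumOptimizer(
    learning_rate=0.001,
    momentum=0.9,
    rampup_begin_step=0)
```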
@@ -291,6 +291,10 @@ class TestDistRunnerBase(object):
             build_stra.num_trainers = 1
             build_stra.trainer_id = 0
 
+        if args.use_dgc:
+            # fuse_all_reduce_ops require that gradients should not be sparse types
+            build_stra.fuse_all_reduce_ops = False
+
         print_to_err(type(self).__name__, "begin to compile with data parallel")
         binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
             loss_name=avg_cost.name,
@@ -572,7 +576,8 @@ class TestDistBase(unittest.TestCase):
                    check_error_log=False,
                    batch_size=DEFAULT_BATCH_SIZE,
                    batch_merge_repeat=1,
-                   log_name=""):
+                   log_name="",
+                   gpus="0"):
 
         cmd = self._python_interp
@@ -592,13 +597,17 @@ class TestDistBase(unittest.TestCase):
         if self.__use_cuda:
             cmd += " --use_cuda"
             env_local = {
-                "CUDA_VISIBLE_DEVICES": "0",
+                "CUDA_VISIBLE_DEVICES": gpus,
                 "PADDLE_TRAINERS_NUM": "1",
                 "PADDLE_TRAINER_ID": "0"
             }
         else:
             env_local = {'CPU_NUM': '1'}
 
+        # not use dgc in single card
+        if len(gpus) > 1 and self._use_dgc:
+            cmd += " --use_dgc"
+
         env_local.update(envs)
         print("local_cmd: {}, env: {}".format(cmd, env_local))
@@ -825,12 +834,7 @@ class TestDistBase(unittest.TestCase):
             print("outs[1]:", outs[1])
         return pickle.loads(outs[0]), pickle.loads(outs[1])
 
-    def check_with_place(self,
-                         model_file,
-                         delta=1e-3,
-                         check_error_log=False,
-                         need_envs={},
-                         log_name=""):
+    def _get_required_envs(self, check_error_log=False, need_envs={}):
         # TODO(typhoonzero): should auto adapt GPU count on the machine.
         required_envs = {
             "PATH": os.getenv("PATH", ""),
@@ -844,13 +848,24 @@ class TestDistBase(unittest.TestCase):
             "NCCL_SHM_DISABLE": "1"
         }
 
-        required_envs.update(need_envs)
-
         if check_error_log:
             required_envs["GLOG_vmodule"] = \
-                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10"
+                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
+                "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
+                "sparse_all_reduce_op_handle=10"
             required_envs["GLOG_logtostderr"] = "1"
 
+        required_envs.update(need_envs)
+        return required_envs
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={},
+                         log_name=""):
+        required_envs = self._get_required_envs(check_error_log, need_envs)
+
         local_losses \
             = self._run_local(model_file, required_envs,
                               check_error_log, log_name=log_name)
@@ -881,3 +896,38 @@ class TestDistBase(unittest.TestCase):
             dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
             print("=======", local_loss, ":", dist_loss[0], "=======")
             self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
+
+    def check_with_place_multi_cards(self,
+                                     model_file,
+                                     delta=1e-3,
+                                     check_error_log=False,
+                                     need_envs={},
+                                     log_name=""):
+        # need open p2p or shm otherwise multi cards mode will hang
+        need_envs.update({"NCCL_P2P_DISABLE": "0", "NCCL_SHM_DISABLE": "0"})
+
+        required_envs = self._get_required_envs(check_error_log, need_envs)
+
+        if self._use_dgc:
+            multi_cards_losses = self._run_local(
+                model_file,
+                required_envs,
+                check_error_log,
+                log_name=log_name + "_dgc_2cards",
+                gpus="0,1")
+
+            self._use_dgc = False
+            base_losses = self._run_local(
+                model_file,
+                required_envs,
+                check_error_log,
+                log_name=log_name + "_base_2cards",
+                gpus="0,1")
+
+            self._use_dgc = True
+
+            for step_id in range(RUN_STEP):
+                base_loss = base_losses[step_id]
+                multi_cards_loss = multi_cards_losses[step_id]
+                print("=======", base_loss, ":", multi_cards_loss, "=======")
+                self.assertAlmostEqual(base_loss, multi_cards_loss, delta=delta)
@@ -38,5 +38,23 @@ class TestDistMnistNCCL2DGC(TestDistBase):
             log_name=flag_name)
 
+class TestDistMnistNCCL2DGCMultiCards(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+        self._use_reader_alloc = False
+        self._nccl2_mode = True
+        self._use_dgc = True
+
+    def test_dist_train(self):
+        import paddle.fluid as fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place_multi_cards(
+                "dist_mnist.py",
+                delta=1e-5,
+                check_error_log=True,
+                log_name=flag_name)
+
 if __name__ == "__main__":
     unittest.main()