From 2e7dc6660b4d695f24ea52aed6d0ee93d9001c7d Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Mon, 17 Oct 2022 15:04:14 +0800
Subject: [PATCH] skip ReplaceAllReduceOp in GraphtoBlock when nccl_ctxs_ is
 nullptr (#46911)

* skip ReplaceAllReduceOp in GraphtoBlock when nccl_ctxs_ is nullptr

* update ut

* test_dist_allreduce_op failed

* fix test_dist_allreduce_op

* add ut

* fix nccl cpu compile

* fix
---
 paddle/fluid/framework/details/nccl_op_handle.h |  4 ++++
 paddle/fluid/framework/ir/graph_helper.cc       | 17 ++++++++++++++++-
 python/paddle/fluid/executor.py                 |  9 ---------
 .../fleet/test_fleet_graph_executor.py          |  6 ++++--
 .../fluid/tests/unittests/test_dist_base.py     |  3 ++-
 5 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/framework/details/nccl_op_handle.h b/paddle/fluid/framework/details/nccl_op_handle.h
index 03e2ee1b9d..5ed4d80203 100644
--- a/paddle/fluid/framework/details/nccl_op_handle.h
+++ b/paddle/fluid/framework/details/nccl_op_handle.h
@@ -81,6 +81,10 @@ class NCCLOpHandleBase : public OpHandleBase {
                       0,
                       platform::errors::Unimplemented(
                           "Not supported use_hierarchical_allreduce_ now"));
+    PADDLE_ENFORCE_NOT_NULL(
+        nccl_ctxs_,
+        platform::errors::NotFound("Can't get flat %d nccl contexts.",
+                                   run_order_));
     auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
     int dev_id = places_[0].device;
     auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index 37dbfe2758..3db9814374 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -514,6 +514,17 @@ void ReplaceAllReduceOp(const Node &node,
   details::OpHandleBase &op_handle =
       const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
 
+  // Even if PADDLE_WITH_NCCL is defined, if the program runs on CPU,
+  // nccl_ctxs_ in NCCLOpHandleBase will be nullptr, and calling the
+  // GetComm() method will report an error.
+  // There are bugs in the all_reduce_op_handle method on CPU devices, so
+  // skip this case temporarily.
+  if (dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetNcclContext() ==
+      nullptr) {
+    VLOG(4) << "Skip replacing allreduce op because nccl_ctxs_ is nullptr.";
+    return;
+  }
+
   std::string all_reduce_var_name;
   // If fused, add check_memory_continue OP to fuse inputs
   if (is_fused) {
@@ -637,10 +648,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
-    } else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    } else if ((n->Name() == "allreduce" || n->Name() == "fused_all_reduce") &&
+               dynamic_cast<details::NCCLOpHandleBase *>(
+                   &(n->Wrapper<details::OpHandleBase>())) != nullptr) {
       VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
       ReplaceAllReduceOp(*n, block, ops);
       VLOG(4) << n->ToString();
+#endif
     } else if (n->Op()) {
       VLOG(4) << "convert op node to desc " << n->Op()->Type();
       if (is_fused_opt(n)) {
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 09fe42e705..d04fab724a 100755
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1566,15 +1566,6 @@ class Executor(object):
             compiled_program = program if isinstance(
                 program, compiler.CompiledProgram) else program._graph
 
-            # delete this code after supporting distribution
-            if compiled_program._build_strategy is not None and (
-                    compiled_program._build_strategy.is_distribution
-                    or compiled_program._build_strategy.num_trainers > 1):
-                warnings.warn(
-                    "Standalone executor is not used for distribution",
-                    UserWarning)
-                return use_standalone_executor_for_distribution
-
             # Unsupported case 1: data parallel
             if compiled_program._is_data_parallel and len(
                     compiled_program._get_places(
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py
index af2a8a1465..3d30baba9a 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py
@@ -29,7 +29,8 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
            "PADDLE_TRAINERS_NUM": "2",
            "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
            "http_proxy": "",
-           "https_proxy": ""
+           "https_proxy": "",
+           "FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
        }
 
        node_b = {
@@ -38,7 +39,8 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
            "PADDLE_TRAINERS_NUM": "2",
            "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
            "http_proxy": "",
-           "https_proxy": ""
+           "https_proxy": "",
+           "FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
        }
 
        def node_func():
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 933d5c6046..59b961fcb0 100755
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -1477,7 +1477,8 @@ class TestDistBase(unittest.TestCase):
            "FLAGS_rpc_disable_reuse_port": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1",
-           "NCCL_SHM_DISABLE": "1"
+           "NCCL_SHM_DISABLE": "1",
+           "FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
        }
 
        if check_error_log:
-- 
GitLab
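
Note: the snippet below is a minimal, self-contained C++ sketch of the guard pattern this patch introduces, for illustration only. FakeNcclOpHandle, its nccl_ctxs member, and ShouldConvertAllReduce are hypothetical stand-ins for details::NCCLOpHandleBase, its nccl_ctxs_ field with the GetNcclContext() accessor, and the check added to ReplaceAllReduceOp; they are not part of the patch or of Paddle's API.

#include <iostream>

struct FakeNcclOpHandle {
  // Mirrors nccl_ctxs_: stays null when the program runs on CPU, because no
  // NCCL communicator context is ever created.
  void* nccl_ctxs = nullptr;
  void* GetNcclContext() const { return nccl_ctxs; }
};

bool ShouldConvertAllReduce(const FakeNcclOpHandle& handle) {
  // Same idea as the patch: bail out of the allreduce -> c_allreduce_sum
  // conversion instead of dereferencing a null NCCL context later.
  if (handle.GetNcclContext() == nullptr) {
    std::cout << "Skip replacing allreduce op because nccl_ctxs_ is nullptr.\n";
    return false;
  }
  return true;
}

int main() {
  FakeNcclOpHandle cpu_only_handle;          // CPU run: no NCCL context
  ShouldConvertAllReduce(cpu_only_handle);   // prints the skip message
  return 0;
}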