未验证 提交 2e7dc666 编写于 作者: P pangyoki 提交者: GitHub

skip ReplaceAllReduceOp in GraphtoBlock when nccl_ctxs_ is nullptr (#46911)

* skip ReplaceAllReduceOp in GraphtoBlock when nccl_ctxs_ is nullptr

* update ut

* test_dist_allreduce_op failed

* fix test_dist_allreduce_op

* add ut

* fix nccl cpu compile

* fix
上级 198c7993
......@@ -81,6 +81,10 @@ class NCCLOpHandleBase : public OpHandleBase {
0,
platform::errors::Unimplemented(
"Not supported use_hierarchical_allreduce_ now"));
PADDLE_ENFORCE_NOT_NULL(
nccl_ctxs_,
platform::errors::NotFound("Can't get flat %d nccl contexts.",
run_order_));
auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
int dev_id = places_[0].device;
auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
......
......@@ -514,6 +514,17 @@ void ReplaceAllReduceOp(const Node &node,
details::OpHandleBase &op_handle =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
// Even if PADDLE_WITH_NCCL is defined, if the program runs on CPU,
// nccl_ctxs_ in NCCLOpHandleBase will be nullptr, and calling the
// GetComm() method will report an error.
// There is bugs in all_reduce_op_handle method on CPU devices, skip
// this case in temporary.
if (dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetNcclContext() ==
nullptr) {
VLOG(4) << "Skip replacing allreduce op because nccl_ctxs_ is nullptr.";
return;
}
std::string all_reduce_var_name;
// If fused, add check_memory_continue OP to fuse inputs
if (is_fused) {
......@@ -637,10 +648,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
} else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
} else if ((n->Name() == "allreduce" || n->Name() == "fused_all_reduce") &&
dynamic_cast<details::NCCLOpHandleBase *>(
&(n->Wrapper<details::OpHandleBase>())) != nullptr) {
VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
ReplaceAllReduceOp(*n, block, ops);
VLOG(4) << n->ToString();
#endif
} else if (n->Op()) {
VLOG(4) << "convert op node to desc " << n->Op()->Type();
if (is_fused_opt(n)) {
......
......@@ -1566,15 +1566,6 @@ class Executor(object):
compiled_program = program if isinstance(
program, compiler.CompiledProgram) else program._graph
# delete this code after supporting distribution
if compiled_program._build_strategy is not None and (
compiled_program._build_strategy.is_distribution
or compiled_program._build_strategy.num_trainers > 1):
warnings.warn(
"Standalone executor is not used for distribution",
UserWarning)
return use_standalone_executor_for_distribution
# Unsupported case 1: data parallel
if compiled_program._is_data_parallel and len(
compiled_program._get_places(
......
......@@ -29,7 +29,8 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
"http_proxy": "",
"https_proxy": ""
"https_proxy": "",
"FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
}
node_b = {
......@@ -38,7 +39,8 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
"http_proxy": "",
"https_proxy": ""
"https_proxy": "",
"FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
}
def node_func():
......
......@@ -1477,7 +1477,8 @@ class TestDistBase(unittest.TestCase):
"FLAGS_rpc_disable_reuse_port": "1",
"http_proxy": "",
"NCCL_P2P_DISABLE": "1",
"NCCL_SHM_DISABLE": "1"
"NCCL_SHM_DISABLE": "1",
"FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
}
if check_error_log:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册