From 500f070d55ee1c20029d21fd165fd793e68e2b98 Mon Sep 17 00:00:00 2001
From: qipengh
Date: Wed, 7 Sep 2022 11:32:50 +0800
Subject: [PATCH] [MLU] fix sync_bn of mlu and add unittests (#45707)

* [MLU] fix sync_bn of mlu and add unittests

* [MLU] remove redundant code of pytest
---
 .../fluid/operators/sync_batch_norm_op_mlu.cc |  57 +++---
 python/paddle/fluid/framework.py              |   4 +-
 .../mlu/parallel_dygraph_sync_batch_norm.py   | 105 ++++++++++
 ...st_parallel_dygraph_sync_batch_norm_mlu.py | 192 ++++++++++++++++++
 .../mlu/test_sync_batch_norm_base_mlu.py      |  38 ++--
 .../mlu/test_sync_batch_norm_op_mlu.sh        |   2 +
 .../fluid/tests/unittests/test_dist_base.py   |  37 +++-
 7 files changed, 388 insertions(+), 47 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/parallel_dygraph_sync_batch_norm.py
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_parallel_dygraph_sync_batch_norm_mlu.py

diff --git a/paddle/fluid/operators/sync_batch_norm_op_mlu.cc b/paddle/fluid/operators/sync_batch_norm_op_mlu.cc
index a2091aa10a..0a95088c31 100644
--- a/paddle/fluid/operators/sync_batch_norm_op_mlu.cc
+++ b/paddle/fluid/operators/sync_batch_norm_op_mlu.cc
@@ -159,9 +159,9 @@ class SyncBatchNormMLUKernel : public framework::OpKernel {
                                             GetBasePtr(&local_var));
 
     Tensor input_count;
-    input_count.mutable_data(phi::make_ddim({1}), ctx.GetPlace());
-    FillMLUTensorWithHostValue(
-        ctx, static_cast(x->numel() / C), &input_count);
+    input_count.mutable_data(phi::make_ddim({1}), ctx.GetPlace());
+    FillMLUTensorWithHostValue(
+        ctx, static_cast(x->numel() / C), &input_count);
 
     Tensor count_all;
     Tensor mean_all(mean->dtype());
@@ -170,15 +170,23 @@ class SyncBatchNormMLUKernel : public framework::OpKernel {
 
 #ifdef PADDLE_WITH_CNCL
     auto &dev_ctx = ctx.template device_context();
-    auto stream = dev_ctx.stream();
     auto *comm = dev_ctx.cncl_comm();
     if (comm) {
-      auto *comm = paddle::platform::CNCLCommContext::Instance()
-                       .Get(0, ctx.GetPlace())
-                       ->comm();
+      auto cncl_comm = paddle::platform::CNCLCommContext::Instance().Get(
+          0, ctx.GetPlace());
+      auto *comm = cncl_comm->comm();
+      auto comm_stream = cncl_comm->stream();
       int count;
       PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCommCount(&count, comm));
-      count_all.mutable_data(phi::make_ddim({count}), ctx.GetPlace());
+      count_all.mutable_data(phi::make_ddim({count}),
+                             ctx.GetPlace());
+      mean_all.mutable_data(phi::make_ddim({count, mean->numel()}),
+                            ctx.GetPlace());
+      invstd_all.mutable_data(
+          phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
+      // Sync the compute stream before any work is queued on comm_stream.
+      dev_ctx.Wait();
+
       cnclDataType_t dtype = platform::ToCNCLDataType(
           framework::TransToProtoVarType(count_all.dtype()));
       PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&input_count),
@@ -186,12 +194,7 @@ class SyncBatchNormMLUKernel : public framework::OpKernel {
                                                1,
                                                dtype,
                                                comm,
-                                               stream));
-
-      mean_all.mutable_data(phi::make_ddim({count, mean->numel()}),
-                            ctx.GetPlace());
-      invstd_all.mutable_data(
-          phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
+                                               comm_stream));
 
       auto cncl_dtype = platform::ToCNCLDataType(
           framework::TransToProtoVarType(mean_all.dtype()));
@@ -200,14 +203,17 @@ class SyncBatchNormMLUKernel : public framework::OpKernel {
                                                local_mean.numel(),
                                                cncl_dtype,
                                                comm,
-                                               stream));
+                                               comm_stream));
 
       PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_var),
                                                GetBasePtr(&invstd_all),
                                                local_var.numel(),
                                                cncl_dtype,
                                                comm,
-                                               stream));
+                                               comm_stream));
+      // Sync comm_stream after the collectives so that the compute stream
+      // sees the gathered statistics correctly.
+      PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream));
 #else
   if (NO_USE_CNCL) {
 #endif
@@ -412,12 +418,14 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel {
 
 #ifdef PADDLE_WITH_CNCL
     auto &dev_ctx = ctx.template device_context();
-    auto stream = dev_ctx.stream();
     auto *comm = dev_ctx.cncl_comm();
     if (comm) {
-      auto *comm = paddle::platform::CNCLCommContext::Instance()
-                       .Get(0, ctx.GetPlace())
-                       ->comm();
+      auto cncl_comm =
+          paddle::platform::CNCLCommContext::Instance().Get(0, ctx.GetPlace());
+      auto *comm = cncl_comm->comm();
+      auto comm_stream = cncl_comm->stream();
+      // Sync the compute stream before any work is queued on comm_stream.
+      dev_ctx.Wait();
       cnclDataType_t dtype = platform::ToCNCLDataType(
           framework::TransToProtoVarType(numel_count.dtype()));
       PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&numel_count),
@@ -426,7 +434,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel {
                                                dtype,
                                                cnclSum,
                                                comm,
-                                               stream));
+                                               comm_stream));
 
       auto cncl_dtype = platform::ToCNCLDataType(
           framework::TransToProtoVarType(sum_dy.dtype()));
@@ -436,7 +444,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel {
                                                cncl_dtype,
                                                cnclSum,
                                                comm,
-                                               stream));
+                                               comm_stream));
 
       PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy_xmu),
                                                GetBasePtr(&sum_dy_xmu),
@@ -444,7 +452,10 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel {
                                                cncl_dtype,
                                                cnclSum,
                                                comm,
-                                               stream));
+                                               comm_stream));
+      // Sync comm_stream after the collectives so that the compute stream
+      // sees the reduced statistics correctly.
+      PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream));
     }
 
 #endif
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 16c4fc6acb..bf56b125fd 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -7220,9 +7220,9 @@ def device_guard(device=None):
         device, index = device.split(':')
         if device == 'cpu':
             raise ValueError("Should not set device id for cpu.")
-    if device not in ['cpu', 'gpu', 'npu', 'xpu', '', None]:
+    if device not in ['cpu', 'gpu', 'npu', 'xpu', 'mlu', '', None]:
         raise ValueError(
-            "The Attr(device) should be 'cpu' 'npu' 'xpu' or 'gpu', and it can also be empty string or None "
+            "The Attr(device) should be 'cpu' 'npu' 'xpu' 'mlu' or 'gpu', and it can also be empty string or None "
            "when there is no need to specify device. But received %s" % device)
     if index:
         device = ":".join([device, index])
diff --git a/python/paddle/fluid/tests/unittests/mlu/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/mlu/parallel_dygraph_sync_batch_norm.py
new file mode 100644
index 0000000000..6f7c0d595c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/parallel_dygraph_sync_batch_norm.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import print_function + +import os +import contextlib +import unittest +import numpy as np +import six +import pickle + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.nn import Conv2D, Linear, SyncBatchNorm +from paddle.fluid.dygraph.base import to_variable +import sys + +sys.path.append("..") +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + + +class TestLayer(fluid.dygraph.Layer): + + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(TestLayer, self).__init__() + + self._conv = Conv2D(in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + bias_attr=False) + + self._sync_batch_norm = SyncBatchNorm(num_filters) + + self._conv2 = Conv2D(in_channels=num_filters, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + bias_attr=False) + + self._sync_batch_norm2 = SyncBatchNorm(num_filters, + weight_attr=False, + bias_attr=False) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._sync_batch_norm(y) + y = self._conv2(y) + y = self._sync_batch_norm2(y) + + return y + + +class TestSyncBatchNorm(TestParallelDyGraphRunnerBase): + + def get_model(self): + model = TestLayer(3, 64, 7) + train_reader = paddle.batch(paddle.dataset.flowers.test(use_xmap=False), + batch_size=32, + drop_last=True) + opt = fluid.optimizer.Adam(learning_rate=1e-3, + parameter_list=model.parameters()) + return model, train_reader, opt + + def run_one_loop(self, model, opt, data): + batch_size = len(data) + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + img = to_variable(dy_x_data) + img.stop_gradient = False + + out = model(img) + + out = paddle.mean(out) + + return out + + +if __name__ == "__main__": + runtime_main(TestSyncBatchNorm) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_parallel_dygraph_sync_batch_norm_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_parallel_dygraph_sync_batch_norm_mlu.py new file mode 100644 index 0000000000..73e41f7896 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_parallel_dygraph_sync_batch_norm_mlu.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import sys + +sys.path.append("..") +import unittest +from test_dist_base import TestDistBase +import paddle.fluid as fluid + +import os +import subprocess +import pickle + +DEFAULT_BATCH_SIZE = 2 + +flag_name = os.path.splitext(__file__)[0] + +print("file: {}".format(flag_name)) + + +class TestParallelDygraphMnistMLU(TestDistBase): + + def _setup_config(self): + self._sync_mode = False + self._cncl_mode = True + self._dygraph = True + self._enforce_place = "MLU" + + def _get_required_envs(self, check_error_log=False, need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_eager_delete_tensor_gb": "0.0", + "FLAGS_call_stack_level": "2", + "GLOG_v": "2", + "PADDLE_WITH_GLOO": '0', + "BACKEND": "cncl" + } + + if check_error_log: + required_envs["GLOG_v"] = "5" + required_envs["GLOG_logtostderr"] = "1" + required_envs["GLOO_LOG_LEVEL"] = "TRACE" + + required_envs.update(need_envs) + return required_envs + + def _run_local(self, + model, + envs, + check_error_log=False, + batch_size=DEFAULT_BATCH_SIZE, + batch_merge_repeat=1, + log_name="", + devices="1"): + + cmd = self._python_interp + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') + cmd += " -m coverage run --branch -p" + + cmd += " %s --role trainer --update_method local --lr %f" % (model, + self._lr) + + if batch_size != DEFAULT_BATCH_SIZE: + cmd += " --batch_size %d" % batch_size + if batch_merge_repeat > 1: + cmd += " --batch_merge_repeat %d" % batch_merge_repeat + if self._nccl2_reduce_layer: + cmd += " --nccl2_reduce_layer_local_run 1" + + if self._use_mlu: + cmd += " --use_mlu" + env_local = { + "FLAGS_selected_mlus": devices, + "PADDLE_TRAINERS_NUM": "1", + "PADDLE_TRAINER_ID": "0" + } + else: + env_local = {'CPU_NUM': '1'} + + # not use dgc in single card + if len(devices) > 1 and self._use_dgc: + cmd += " --use_dgc" + + if self._accumulate_gradient: + cmd += " --accumulate_gradient" + + if self._find_unused_parameters: + cmd += " --find_unused_parameters" + + env_local.update(envs) + print("local_cmd: {}, env: {}".format(cmd, env_local)) + + if check_error_log: + path = "/tmp/local_err_%d.log" % os.getpid() + err_log = open(path, "w") + local_proc = subprocess.Popen(cmd.split(" "), + stdout=subprocess.PIPE, + stderr=err_log, + env=env_local) + else: + local_proc = subprocess.Popen(cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env_local) + + local_out, local_err = local_proc.communicate() + + if check_error_log: + err_log.close() + sys.stderr.write( + '\n--run_local-- trainer 0 stderr file saved in: %s\n' % (path)) + + sys.stderr.write('local_stderr: %s\n' % local_err) + sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out)) + + return pickle.loads(local_out) + + def _run_cluster_nccl2(self, model, envs, update_method, check_error_log, + log_name): + # NOTE: we reuse ps_endpoints as nccl2 worker endpoints + worker_endpoints = self._ps_endpoints.split(",") + + trainer_num = len(worker_endpoints) + + procs = [] + pipes = [] + for i in range(0, trainer_num): + tr_cmd, tr_env = self._get_nccl2_trainer_cmd( + model, worker_endpoints[i], update_method, i, trainer_num) + tr_env.update(envs) + print("use_hallreduce:{} \ntr{}_cmd:{}, env: {}".format( + self._use_hallreduce, i, tr_cmd, tr_env)) 
+ + tr_pipe = open("/tmp/tr%d_err_%d.log" % (i, os.getpid()), "w") + + sys.stderr.write( + "\n{} going to start process {} with nccl2\n".format( + type(self).__name__, i)) + tr_proc = subprocess.Popen(tr_cmd.strip().split(" "), + stdout=subprocess.PIPE, + stderr=tr_pipe, + env=tr_env) + + procs.append(tr_proc) + pipes.append(tr_pipe) + + outs = [] + for i in range(0, trainer_num): + tr_out, tr_err = procs[i].communicate() + outs.append(tr_out) + pipes[i].close() + sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err)) + sys.stderr.write( + 'trainer {} glog file saved in: /tmp/tr{}_err_{}.log \n'.format( + i, i, os.getpid())) + + if check_error_log: + print("outs[0]:", pickle.loads(outs[0])) + print("outs[1]:", pickle.loads(outs[1])) + + return pickle.loads(outs[0]), pickle.loads(outs[1]) + + def test_mnist(self): + if fluid.core.is_compiled_with_mlu(): + self.check_with_place( + os.path.abspath("parallel_dygraph_sync_batch_norm.py"), + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py index 3c774e4701..3b8dd2c192 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py @@ -126,19 +126,19 @@ class TestSyncBatchNormRunnerBase(object): for layout in ["NCHW", "NHWC"]: self._compare(args, place, layout, True) - # # Test FP16 - @TODO - # self.dtype = np.float16 - # self.atol = 1e-2 + # Test FP16 - @TODO + self.dtype = np.float16 + self.atol = 1e-2 - # # Test training - # for place in places: - # for layout in ["NCHW", "NHWC"]: - # self._compare(args, place, layout, False) + # Test training + for place in places: + for layout in ["NCHW", "NHWC"]: + self._compare(args, place, layout, False) - # # Test inference - # for place in places: - # for layout in ["NCHW", "NHWC"]: - # self._compare(args, place, layout, True) + # Test inference + for place in places: + for layout in ["NCHW", "NHWC"]: + self._compare(args, place, layout, True) sys.stdout.buffer.write( pickle.dumps( @@ -333,8 +333,8 @@ class TestSyncBatchNormRunnerBase(object): self.initCommunicator(startup_prog, rank, nranks, True, current_endpoint, endpoints) - sys.stderr.write("after init, startup_prog: " + - startup_prog.to_string(True) + "\n") + # sys.stderr.write("after init, startup_prog: " + + # startup_prog.to_string(True) + "\n") train_prog.global_seed(SEED) train_prog._sync_with_cpp() startup_prog.global_seed(SEED) @@ -344,10 +344,10 @@ class TestSyncBatchNormRunnerBase(object): self.rank = rank outs = self.get_model(train_prog, startup_prog, place, layout, SEED, True, only_forward) - sys.stderr.write("after get_model, train_prog: " + - train_prog.to_string(True) + "\n") - sys.stderr.write("after get_model, startup_prog: " + - startup_prog.to_string(True) + "\n") + # sys.stderr.write("after get_model, train_prog: " + + # train_prog.to_string(True) + "\n") + # sys.stderr.write("after get_model, startup_prog: " + + # startup_prog.to_string(True) + "\n") ops = train_prog.blocks[0].ops for i, op in enumerate(ops): @@ -360,8 +360,8 @@ class TestSyncBatchNormRunnerBase(object): sys.stderr.write("op type: " + op.type + "\n") op.desc.set_type('sync_batch_norm_grad') - sys.stderr.write("after update sync_batch_norm, train_prog: " + - train_prog.to_string(True) + "\n") + # sys.stderr.write("after update 
sync_batch_norm, train_prog: " + + # train_prog.to_string(True) + "\n") exe = fluid.Executor(place) exe.run(startup_prog) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu.sh b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu.sh index 1417acb4be..7be86acd40 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu.sh +++ b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu.sh @@ -17,3 +17,5 @@ set -e MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch test_sync_batch_norm_op_mlu_baseline.py + +MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch test_parallel_dygraph_sync_batch_norm_mlu.py diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 70b1d0568a..cf3dcd00a5 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -552,6 +552,9 @@ class TestParallelDyGraphRunnerBase(object): elif fluid.core.is_compiled_with_npu(): device_id = int(os.getenv("FLAGS_selected_npus", "0")) place = fluid.NPUPlace(device_id) + elif fluid.core.is_compiled_with_mlu(): + device_id = int(os.getenv("FLAGS_selected_mlus", "0")) + place = fluid.MLUPlace(device_id) else: assert ("Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.") @@ -565,7 +568,7 @@ class TestParallelDyGraphRunnerBase(object): nranks = len(args.endpoints.split(",")) if args.endpoints else 1 #if args.update_method == "nccl2": - if args.update_method == "nccl2" or args.update_method == "bkcl" or args.update_method == "hccl": + if args.update_method == "nccl2" or args.update_method == "bkcl" or args.update_method == "hccl" or args.update_method == "cncl": strategy = dygraph.parallel.ParallelStrategy() strategy.nranks = nranks strategy.local_rank = args.trainer_id @@ -708,7 +711,7 @@ def runtime_main(test_class): default="local", choices=[ "pserver", "nccl2", "bkcl", "local", - "nccl2_reduce_layer", "gloo", "hccl" + "nccl2_reduce_layer", "gloo", "hccl", "cncl" ]) parser.add_argument('--trainer_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) @@ -735,6 +738,7 @@ def runtime_main(test_class): parser.add_argument('--use_xpu', action='store_true') parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--use_npu', action='store_true') + parser.add_argument('--use_mlu', action='store_true') parser.add_argument('--accumulate_gradient', action='store_true') parser.add_argument('--find_unused_parameters', action='store_true') parser.add_argument('--use_reduce', action='store_true') @@ -794,20 +798,30 @@ class TestDistBase(unittest.TestCase): self.__use_xpu = False self._use_dgc = False self.__use_npu = False + self._use_mlu = False elif self._enforce_place == "GPU": self.__use_cuda = True self.__use_xpu = False self.__use_npu = False + self._use_mlu = False elif self._enforce_place == "XPU": self.__use_cuda = False self.__use_xpu = True self._use_dgc = False self.__use_npu = False + self._use_mlu = False elif self._enforce_place == "NPU": self.__use_cuda = False self.__use_xpu = False self._use_dgc = False self.__use_npu = True + self._use_mlu = False + elif self._enforce_place == "MLU": + self.__use_cuda = False + self.__use_xpu = False + self._use_dgc = False + self.__use_npu = False + self._use_mlu = True else: if fluid.core.is_compiled_with_cuda(): self.__use_cuda = True @@ -833,6 +847,7 @@ class 
TestDistBase(unittest.TestCase): self._bkcl_mode = False self._gloo_mode = False # now, support gloo backend self._hccl_mode = False + self._cncl_mode = False self._pipeline_mode = False self._mp_mode = False self._diff_batch = False @@ -1243,6 +1258,16 @@ class TestDistBase(unittest.TestCase): "PADDLE_CURRENT_ENDPOINT": ep, "GLOG_v": "2", }) + elif self._use_mlu: + tr_cmd += " --use_mlu" + env.update({ + "FLAGS_selected_mlus": "{}".format(trainer_id), + "PADDLE_TRAINERS_NUM": "{}".format(trainer_num), + "PADDLE_TRAINER_ID": "{}".format(trainer_id), + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": ep, + "GLOG_v": "4", + }) else: env.update({'CPU_NUM': '1'}) @@ -1556,7 +1581,13 @@ class TestDistBase(unittest.TestCase): update_method='hccl', check_error_log=check_error_log, log_name=log_name) - + elif self._cncl_mode: + tr0_losses, tr1_losses = self._run_cluster_nccl2( + model_file, + required_envs, + update_method='cncl', + check_error_log=check_error_log, + log_name=log_name) elif self._pipeline_mode: tr0_losses, tr1_losses = self._run_pipeline(model_file, required_envs, -- GitLab
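
A note on the framework.py hunk above: it adds 'mlu' to the device list accepted
by fluid.framework.device_guard (also exposed as paddle.static.device_guard).
The following is a minimal usage sketch, assuming a Paddle build with MLU
support; the program structure and variable names are illustrative only.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    data = paddle.static.data(name="x", shape=[4, 3], dtype="float32")
    # With this patch applied, "mlu:0" passes the device check; before the
    # patch, device_guard raised ValueError because 'mlu' was not whitelisted.
    with paddle.static.device_guard("mlu:0"):
        out = paddle.mean(data)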
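
The new parallel_dygraph_sync_batch_norm.py exercises paddle.nn.SyncBatchNorm
inside the TestParallelDyGraphRunnerBase training loop. Below is a condensed,
single-card sketch of the same conv + SyncBatchNorm pattern; the shapes are
illustrative, and the batch-norm statistics are only synchronized across cards
when the script is launched on multiple MLUs.

import numpy as np

import paddle
import paddle.nn as nn

paddle.disable_static()  # dygraph mode, the default in Paddle 2.x


class SmallNet(nn.Layer):

    def __init__(self, num_channels=3, num_filters=64, filter_size=7):
        super(SmallNet, self).__init__()
        self._conv = nn.Conv2D(num_channels,
                               num_filters,
                               filter_size,
                               padding=(filter_size - 1) // 2,
                               bias_attr=False)
        # On a single card SyncBatchNorm behaves like ordinary batch norm.
        self._bn = nn.SyncBatchNorm(num_filters)

    def forward(self, x):
        return paddle.mean(self._bn(self._conv(x)))


model = SmallNet()
x = paddle.to_tensor(np.random.rand(2, 3, 224, 224).astype('float32'))
loss = model(x)
loss.backward()

For an existing model, nn.SyncBatchNorm.convert_sync_batchnorm(model) swaps the
batch-norm layers in place, which is usually more convenient than declaring
SyncBatchNorm layers by hand as the test model does.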
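
The MLU test reuses the TestDistBase machinery shown above: each trainer
process pickles its list of per-step losses to stdout, and check_with_place
compares the single-card run against the two-card run within delta. The sketch
below is a rough illustration of that final comparison only, not the exact
implementation in test_dist_base.py (here the two trainers' losses are simply
averaged; the real helper may combine them differently).

import pickle

import numpy as np


def compare_losses(local_out, tr0_out, tr1_out, delta=1e-5):
    # Each subprocess writes pickle.dumps(list_of_per_step_losses) to stdout.
    local_losses = pickle.loads(local_out)
    tr0_losses = pickle.loads(tr0_out)
    tr1_losses = pickle.loads(tr1_out)
    for step, (local, l0, l1) in enumerate(
            zip(local_losses, tr0_losses, tr1_losses)):
        dist = (np.array(l0) + np.array(l1)) / 2.0  # average over the 2 cards
        np.testing.assert_allclose(np.array(local),
                                   dist,
                                   atol=delta,
                                   err_msg="step %d diverged" % step)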